diff --git a/.gitignore b/.gitignore new file mode 100644 index 0000000..9fe17bc --- /dev/null +++ b/.gitignore @@ -0,0 +1,129 @@ +# Byte-compiled / optimized / DLL files +__pycache__/ +*.py[cod] +*$py.class + +# C extensions +*.so + +# Distribution / packaging +.Python +build/ +develop-eggs/ +dist/ +downloads/ +eggs/ +.eggs/ +lib/ +lib64/ +parts/ +sdist/ +var/ +wheels/ +pip-wheel-metadata/ +share/python-wheels/ +*.egg-info/ +.installed.cfg +*.egg +MANIFEST + +# PyInstaller +# Usually these files are written by a python script from a template +# before PyInstaller builds the exe, so as to inject date/other infos into it. +*.manifest +*.spec + +# Installer logs +pip-log.txt +pip-delete-this-directory.txt + +# Unit test / coverage reports +htmlcov/ +.tox/ +.nox/ +.coverage +.coverage.* +.cache +nosetests.xml +coverage.xml +*.cover +*.py,cover +.hypothesis/ +.pytest_cache/ + +# Translations +*.mo +*.pot + +# Django stuff: +*.log +local_settings.py +db.sqlite3 +db.sqlite3-journal + +# Flask stuff: +instance/ +.webassets-cache + +# Scrapy stuff: +.scrapy + +# Sphinx documentation +docs/_build/ + +# PyBuilder +target/ + +# Jupyter Notebook +.ipynb_checkpoints + +# IPython +profile_default/ +ipython_config.py + +# pyenv +.python-version + +# pipenv +# According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control. +# However, in case of collaboration, if having platform-specific dependencies or dependencies +# having no cross-platform support, pipenv may install dependencies that don't work, or not +# install all needed dependencies. +#Pipfile.lock + +# PEP 582; used by e.g. 
github.com/David-OConnor/pyflow +__pypackages__/ + +# Celery stuff +celerybeat-schedule +celerybeat.pid + +# SageMath parsed files +*.sage.py + +# Environments +.env +.venv +env/ +venv/ +ENV/ +env.bak/ +venv.bak/ + +# Spyder project settings +.spyderproject +.spyproject + +# Rope project settings +.ropeproject + +# mkdocs documentation +/site + +# mypy +.mypy_cache/ +.dmypy.json +dmypy.json + +# Pyre type checker +.pyre/ \ No newline at end of file diff --git a/LICENSE b/LICENSE new file mode 100644 index 0000000..6e83b10 --- /dev/null +++ b/LICENSE @@ -0,0 +1,21 @@ +MIT License + +Copyright (c) 2023 Shengpu Tang + +Permission is hereby granted, free of charge, to any person obtaining a copy +of this software and associated documentation files (the "Software"), to deal +in the Software without restriction, including without limitation the rights +to use, copy, modify, merge, publish, distribute, sublicense, and/or sell +copies of the Software, and to permit persons to whom the Software is +furnished to do so, subject to the following conditions: + +The above copyright notice and this permission notice shall be included in all +copies or substantial portions of the Software. + +THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR +IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, +FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE +AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER +LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, +OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE +SOFTWARE. 
diff --git a/README.md b/README.md new file mode 100644 index 0000000..16a61f2 --- /dev/null +++ b/README.md @@ -0,0 +1,41 @@ +# Counterfactual-Augmented Importance Sampling for Semi-Offline Policy Evaluation + +This repository contains the source code for replicating all experiments in the NeurIPS 2023 paper, "Counterfactual-Augmented Importance Sampling for Semi-Offline Policy Evaluation". + +Repository content: + +- `synthetic/` contains code for the experiments on the toy bandit problems. +- `sepsisSim/` contains code for the experiments on the sepsis simulator. + +If you use this code in your research, please cite the following publication: +``` +@inproceedings{tang2023counterfactual, + title={Counterfactual-Augmented Importance Sampling for Semi-Offline Policy Evaluation}, + author={Tang, Shengpu and Wiens, Jenna}, + booktitle={Advances in Neural Information Processing Systems}, + year={2023}, + url={https://openreview.net/forum?id=dsH244r9fA} +} +``` + +## Synthetic domains - Bandits (Sec 5.1 & Appx E.1) +- `bandit_compare-2state.ipynb`: Table 1, Appx E Table 3 +- `bandit_compare-1state.ipynb`: Appx E Table 4 +- `bandit_sweepW.ipynb`: Appx E Fig 7 (varying weights) +- `bandit_sweepPcannot.ipynb`: Appx E Fig 8 (varying percent annotated and imputation) + +## Healthcare domain - Sepsis simulator (Sec 5.2 & Appx E.2) +- Simulator based on publicly available code at https://github.com/clinicalml/gumbel-max-scm/tree/sim-v2 +- Experiment code structure is inspired by https://github.com/MLD3/OfflineRL_ModelSelection and https://github.com/MLD3/OfflineRL_FactoredActions +- The preparation steps are in `data-prep/`, which includes the simulator source code as well as several notebooks for dataset generation. The output is saved to `data/` (ground-truth MDP parameters, ground-truth optimal policy, and optimal value functions) and `datagen/` (offline datasets). +- The code for the main experiments is in `experiments/`. 
+ - `0-prepopulate_annotations.ipynb` preprocesses the data so that all counterfactual annotations are pre-populated. This step is run only once (instead of repeated for each experiment) as it takes a long time. The annotations are later replaced with different values or removed depending on the experiment. The output is saved to `results/vaso_eps_0_1-annotOpt_df_seed2_aug_step.pkl`. + - Use `commands.sh` to run the experiments. + - Run the following notebooks to generate tables and figures used in the paper: + - `results.ipynb`: Table 2, Appx E Table 5 top, Appx E Fig 9 + - `results-OIS-WIS.ipynb`: Appx E Table 5 bottom + - `plots-analyses.ipynb`: Fig 5 left, Appx E Fig 10 + - `plots-noisy[-v2].ipynb`: Fig 5 center, Appx E Fig 11 left&center + - `plots-missing.ipynb`: Fig 5 right, Appx E Fig 11 right + - `plots--legend.ipynb`: figure legend used in Fig 5 and Fig 11 + - `fig/` contains the main figures used in the paper. diff --git a/sepsisSim/data-prep/0-compute-MDP-parameters.ipynb b/sepsisSim/data-prep/0-compute-MDP-parameters.ipynb new file mode 100644 index 0000000..ff64f4c --- /dev/null +++ b/sepsisSim/data-prep/0-compute-MDP-parameters.ipynb @@ -0,0 +1,1466 @@ +{ + "cells": [ + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "# Compute the exact transition matrix for the sepsis simulator v2\n", + "https://github.com/clinicalml/gumbel-max-scm/tree/854229e039b52f10257ad5460fa79d34f0452b27/sepsisSimDiabetes" + ] + }, + { + "cell_type": "code", + "execution_count": 1, + "metadata": {}, + "outputs": [], + "source": [ + "# Apply transitions in the following order:\n", + "# abx on/off\n", + "# vent on/off\n", + "# vaso on/off\n", + "# hr fluctuate\n", + "# sbp fluctuate\n", + "# o2 fluctuate\n", + "# glu fluctuate" + ] + }, + { + "cell_type": "code", + "execution_count": 2, + "metadata": {}, + "outputs": [], + "source": [ + "import numpy as np\n", + "from sepsisSimDiabetes.State import State\n", + "from sepsisSimDiabetes.Action import Action\n", + "from 
sepsisSimDiabetes.MDP import MDP\n", + "import itertools\n", + "import joblib" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "## Metadata" + ] + }, + { + "cell_type": "code", + "execution_count": 3, + "metadata": {}, + "outputs": [], + "source": [ + "# state_categs\n", + "# [hr, sbp, o2, glu, abx, vent, vaso, diab]\n", + "\n", + "state_variable_values = {\n", + " 'hr': [0,1,2], \n", + " 'sbp': [0,1,2], \n", + " 'o2': [0,1], \n", + " 'glu': [0,1,2,3,4], \n", + " 'abx': [0,1], \n", + " 'vaso': [0,1], \n", + " 'vent': [0,1], \n", + " 'diab': [0,1],\n", + "}\n", + "state_variables = list(state_variable_values.keys())" + ] + }, + { + "cell_type": "code", + "execution_count": 4, + "metadata": {}, + "outputs": [], + "source": [ + "nS = 720\n", + "nA = 8" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "## Reward Matrix (A,S,S) (S,A)" + ] + }, + { + "cell_type": "code", + "execution_count": 5, + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "303 non-terminal states\n", + "416 death states\n", + "1 discharge states\n" + ] + } + ], + "source": [ + "dummy_pol = np.ones((nS, nA)) / nA\n", + "reward_per_state = np.zeros((nS))\n", + "for s in range(nS):\n", + " this_mdp = MDP(init_state_idx=s, policy_array=dummy_pol, p_diabetes=0)\n", + " r = this_mdp.calculateReward()\n", + " reward_per_state[s] = r\n", + "\n", + "print((reward_per_state == 0).sum(), 'non-terminal states')\n", + "print((reward_per_state == -1).sum(), 'death states')\n", + "print((reward_per_state == 1).sum(), 'discharge states')" + ] + }, + { + "cell_type": "code", + "execution_count": 6, + "metadata": {}, + "outputs": [], + "source": [ + "reward_matrix_ASS = np.zeros((nA, nS*2, nS*2))\n", + "for s in range(nS):\n", + " reward_matrix_ASS[:, :nS, s] = reward_per_state[s]\n", + " reward_matrix_ASS[:, nS:, nS+s] = reward_per_state[s]" + ] + }, + { + "cell_type": "code", + "execution_count": 7, + "metadata": {}, 
+ "outputs": [], + "source": [ + "# Assign reward for the transition from death/disch\n", + "reward_matrix_absorbing_SA = np.zeros((nS*2+2, nA))\n", + "for s in range(nS):\n", + " if reward_per_state[s] == -1:\n", + " reward_matrix_absorbing_SA[s, :] = -1\n", + " reward_matrix_absorbing_SA[nS+s, :] = -1\n", + " elif reward_per_state[s] == 1:\n", + " reward_matrix_absorbing_SA[s, :] = 1\n", + " reward_matrix_absorbing_SA[nS+s, :] = 1" + ] + }, + { + "cell_type": "code", + "execution_count": 8, + "metadata": {}, + "outputs": [], + "source": [ + "reward_matrix_absorbing_ASS = np.zeros((nA, nS*2+2, nS*2+2))\n", + "\n", + "# Assign reward for the transition from death/disch leading to the corresponding absorbing state\n", + "reward_matrix_absorbing_ASS[..., -2] = -1\n", + "reward_matrix_absorbing_ASS[..., -1] = 1\n", + "\n", + "# No reward once in aborbing state\n", + "reward_matrix_absorbing_ASS[..., -2, -2] = 0 \n", + "reward_matrix_absorbing_ASS[..., -1, -1] = 0" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "## Treatments" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "### Antibiotics" + ] + }, + { + "cell_type": "code", + "execution_count": 9, + "metadata": {}, + "outputs": [], + "source": [ + "# abx indicator\n", + "abx_on = np.zeros((nS, nS))\n", + "abx_off = np.zeros((nS, nS))\n", + "for (hr, sbp, o2, glu, vent, vaso, _) in itertools.product(\n", + " *[state_variable_values[key] for key in state_variables if key not in ['abx']]\n", + "):\n", + " s0 = State(state_categs=[hr, sbp, o2, glu, 0, vaso, vent], diabetic_idx=0).get_state_idx()\n", + " s1 = State(state_categs=[hr, sbp, o2, glu, 1, vaso, vent], diabetic_idx=0).get_state_idx()\n", + " abx_on[s0, s1] = 1\n", + " abx_on[s1, s1] = 1\n", + " abx_off[s1, s0] = 1\n", + " abx_off[s0, s0] = 1\n", + "\n", + "assert np.isclose(abx_on.sum(axis=1), 1).all()\n", + "assert np.isclose(abx_off.sum(axis=1), 1).all()" + ] + }, + { + "cell_type": "code", + 
"execution_count": 10, + "metadata": {}, + "outputs": [], + "source": [ + "# abx affects hr: high->normal wp 0.5\n", + "hr_H2N_wp05 = np.zeros((nS, nS))\n", + "for (sbp, o2, glu, abx, vent, vaso, _) in itertools.product(\n", + " *[state_variable_values[key] for key in state_variables if key not in ['hr']]\n", + "):\n", + " s0 = State(state_categs=[0, sbp, o2, glu, abx, vaso, vent], diabetic_idx=0).get_state_idx()\n", + " s1 = State(state_categs=[1, sbp, o2, glu, abx, vaso, vent], diabetic_idx=0).get_state_idx()\n", + " s2 = State(state_categs=[2, sbp, o2, glu, abx, vaso, vent], diabetic_idx=0).get_state_idx()\n", + " hr_H2N_wp05[s2, s2] = 0.5\n", + " hr_H2N_wp05[s2, s1] = 0.5\n", + " hr_H2N_wp05[s1, s1] = 1\n", + " hr_H2N_wp05[s0, s0] = 1\n", + "\n", + "assert np.isclose(hr_H2N_wp05.sum(axis=1), 1).all()" + ] + }, + { + "cell_type": "code", + "execution_count": 11, + "metadata": {}, + "outputs": [], + "source": [ + "# abx affects sbp: high->normal wp 0.5\n", + "sbp_H2N_wp05 = np.zeros((nS, nS))\n", + "for (hr, o2, glu, abx, vent, vaso, _) in itertools.product(\n", + " *[state_variable_values[key] for key in state_variables if key not in ['sbp']]\n", + "):\n", + " s0 = State(state_categs=[hr, 0, o2, glu, abx, vaso, vent], diabetic_idx=0).get_state_idx()\n", + " s1 = State(state_categs=[hr, 1, o2, glu, abx, vaso, vent], diabetic_idx=0).get_state_idx()\n", + " s2 = State(state_categs=[hr, 2, o2, glu, abx, vaso, vent], diabetic_idx=0).get_state_idx()\n", + " sbp_H2N_wp05[s2, s2] = 0.5\n", + " sbp_H2N_wp05[s2, s1] = 0.5\n", + " sbp_H2N_wp05[s1, s1] = 1\n", + " sbp_H2N_wp05[s0, s0] = 1\n", + "\n", + "assert np.isclose(sbp_H2N_wp05.sum(axis=1), 1).all()" + ] + }, + { + "cell_type": "code", + "execution_count": 12, + "metadata": {}, + "outputs": [], + "source": [ + "# abx withdrawn affects hr: normal->high wp 0.1\n", + "hr_N2H_wp01 = np.zeros((nS, nS))\n", + "for (sbp, o2, glu, abx, vent, vaso, _) in itertools.product(\n", + " *[state_variable_values[key] for key in 
state_variables if key not in ['hr']]\n", + "):\n", + " s0 = State(state_categs=[0, sbp, o2, glu, abx, vaso, vent], diabetic_idx=0).get_state_idx()\n", + " s1 = State(state_categs=[1, sbp, o2, glu, abx, vaso, vent], diabetic_idx=0).get_state_idx()\n", + " s2 = State(state_categs=[2, sbp, o2, glu, abx, vaso, vent], diabetic_idx=0).get_state_idx()\n", + " s0_0 = State(state_categs=[0, sbp, o2, glu, 0, vaso, vent], diabetic_idx=0).get_state_idx()\n", + " s1_0 = State(state_categs=[1, sbp, o2, glu, 0, vaso, vent], diabetic_idx=0).get_state_idx()\n", + " s2_0 = State(state_categs=[2, sbp, o2, glu, 0, vaso, vent], diabetic_idx=0).get_state_idx()\n", + " s0_1 = State(state_categs=[0, sbp, o2, glu, 1, vaso, vent], diabetic_idx=0).get_state_idx()\n", + " s1_1 = State(state_categs=[1, sbp, o2, glu, 1, vaso, vent], diabetic_idx=0).get_state_idx()\n", + " s2_1 = State(state_categs=[2, sbp, o2, glu, 1, vaso, vent], diabetic_idx=0).get_state_idx()\n", + " hr_N2H_wp01[s2, s2] = 1\n", + " hr_N2H_wp01[s1_1, s1_1] = 0.9\n", + " hr_N2H_wp01[s1_1, s2_1] = 0.1\n", + " hr_N2H_wp01[s1_0, s1_0] = 1\n", + " hr_N2H_wp01[s0, s0] = 1\n", + "\n", + "assert np.isclose(hr_N2H_wp01.sum(axis=1), 1).all()" + ] + }, + { + "cell_type": "code", + "execution_count": 13, + "metadata": {}, + "outputs": [], + "source": [ + "# abx withdrawn affects sbp: normal->high wp 0.1\n", + "sbp_N2H_wp01 = np.zeros((nS, nS))\n", + "for (hr, o2, glu, abx, vent, vaso, _) in itertools.product(\n", + " *[state_variable_values[key] for key in state_variables if key not in ['sbp']]\n", + "):\n", + " s0 = State(state_categs=[hr, 0, o2, glu, abx, vaso, vent], diabetic_idx=0).get_state_idx()\n", + " s1 = State(state_categs=[hr, 1, o2, glu, abx, vaso, vent], diabetic_idx=0).get_state_idx()\n", + " s2 = State(state_categs=[hr, 2, o2, glu, abx, vaso, vent], diabetic_idx=0).get_state_idx()\n", + " s0_0 = State(state_categs=[hr, 0, o2, glu, 0, vaso, vent], diabetic_idx=0).get_state_idx()\n", + " s1_0 = State(state_categs=[hr, 1, 
o2, glu, 0, vaso, vent], diabetic_idx=0).get_state_idx()\n", + " s2_0 = State(state_categs=[hr, 2, o2, glu, 0, vaso, vent], diabetic_idx=0).get_state_idx()\n", + " s0_1 = State(state_categs=[hr, 0, o2, glu, 1, vaso, vent], diabetic_idx=0).get_state_idx()\n", + " s1_1 = State(state_categs=[hr, 1, o2, glu, 1, vaso, vent], diabetic_idx=0).get_state_idx()\n", + " s2_1 = State(state_categs=[hr, 2, o2, glu, 1, vaso, vent], diabetic_idx=0).get_state_idx()\n", + " sbp_N2H_wp01[s2, s2] = 1\n", + " sbp_N2H_wp01[s1_1, s1_1] = 0.9\n", + " sbp_N2H_wp01[s1_1, s2_1] = 0.1\n", + " sbp_N2H_wp01[s1_0, s1_0] = 1\n", + " sbp_N2H_wp01[s0, s0] = 1\n", + "\n", + "assert np.isclose(sbp_N2H_wp01.sum(axis=1), 1).all()" + ] + }, + { + "cell_type": "code", + "execution_count": 14, + "metadata": {}, + "outputs": [], + "source": [ + "# antibiotics on\n", + "# hr: hi -> normal w.p. 0.5\n", + "# sbp: hi -> normal w.p. 0.5\n", + "antibiotics_on_ = hr_H2N_wp05 @ sbp_H2N_wp05 @ abx_on\n", + "antibiotics_on = np.block([[antibiotics_on_, np.zeros((nS, nS))], [np.zeros((nS, nS)), antibiotics_on_]])\n", + "assert np.isclose(antibiotics_on.sum(axis=1), 1).all()\n", + "\n", + "# antibiotics off\n", + "# if antibiotics was on\n", + "# hr: normal -> hi w.p. 0.1\n", + "# sbp: normal -> hi w.p. 
0.1\n", + "antibiotics_off_ = hr_N2H_wp01 @ sbp_N2H_wp01 @ abx_off\n", + "antibiotics_off = np.block([[antibiotics_off_, np.zeros((nS, nS))], [np.zeros((nS, nS)), antibiotics_off_]])\n", + "assert np.isclose(antibiotics_off.sum(axis=1), 1).all()" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "### Ventilation" + ] + }, + { + "cell_type": "code", + "execution_count": 15, + "metadata": {}, + "outputs": [], + "source": [ + "# vent indicator\n", + "vent_on = np.zeros((nS, nS))\n", + "vent_off = np.zeros((nS, nS))\n", + "for (hr, sbp, o2, glu, abx, vaso, _) in itertools.product(\n", + " *[state_variable_values[key] for key in state_variables if key not in ['vent']]\n", + "):\n", + " s0 = State(state_categs=[hr, sbp, o2, glu, abx, vaso, 0], diabetic_idx=0).get_state_idx()\n", + " s1 = State(state_categs=[hr, sbp, o2, glu, abx, vaso, 1], diabetic_idx=0).get_state_idx()\n", + " vent_on[s0, s1] = 1\n", + " vent_on[s1, s1] = 1\n", + " vent_off[s1, s0] = 1\n", + " vent_off[s0, s0] = 1\n", + "\n", + "assert np.isclose(vent_on.sum(axis=1), 1).all()\n", + "assert np.isclose(vent_off.sum(axis=1), 1).all()" + ] + }, + { + "cell_type": "code", + "execution_count": 16, + "metadata": {}, + "outputs": [], + "source": [ + "# vent affects o2: low->normal wp 0.7\n", + "o2_L2N_wp07 = np.zeros((nS, nS))\n", + "for (hr, sbp, glu, abx, vent, vaso, _) in itertools.product(\n", + " *[state_variable_values[key] for key in state_variables if key not in ['o2']]\n", + "):\n", + " s0 = State(state_categs=[hr, sbp, 0, glu, abx, vaso, vent], diabetic_idx=0).get_state_idx()\n", + " s1 = State(state_categs=[hr, sbp, 1, glu, abx, vaso, vent], diabetic_idx=0).get_state_idx()\n", + " o2_L2N_wp07[s0, s0] = 0.3\n", + " o2_L2N_wp07[s0, s1] = 0.7\n", + " o2_L2N_wp07[s1, s1] = 1\n", + "\n", + "assert np.isclose(o2_L2N_wp07.sum(axis=1), 1).all()" + ] + }, + { + "cell_type": "code", + "execution_count": 17, + "metadata": {}, + "outputs": [], + "source": [ + "# vent withdrawn affects 
o2: normal->low wp 0.1\n", + "o2_N2L_wp01 = np.zeros((nS, nS))\n", + "for (hr, sbp, glu, abx, vent, vaso, _) in itertools.product(\n", + " *[state_variable_values[key] for key in state_variables if key not in ['o2']]\n", + "):\n", + " s0 = State(state_categs=[hr, sbp, 0, glu, abx, vaso, vent], diabetic_idx=0).get_state_idx()\n", + " s1 = State(state_categs=[hr, sbp, 1, glu, abx, vaso, vent], diabetic_idx=0).get_state_idx()\n", + " s0_0 = State(state_categs=[hr, sbp, 0, glu, abx, vaso, 0], diabetic_idx=0).get_state_idx()\n", + " s1_0 = State(state_categs=[hr, sbp, 1, glu, abx, vaso, 0], diabetic_idx=0).get_state_idx()\n", + " s0_1 = State(state_categs=[hr, sbp, 0, glu, abx, vaso, 1], diabetic_idx=0).get_state_idx()\n", + " s1_1 = State(state_categs=[hr, sbp, 1, glu, abx, vaso, 1], diabetic_idx=0).get_state_idx()\n", + " o2_N2L_wp01[s0_0, s0_0] = 1\n", + " o2_N2L_wp01[s0_1, s0_1] = 1\n", + " o2_N2L_wp01[s1_0, s1_0] = 1\n", + " o2_N2L_wp01[s1_1, s0_1] = 0.1\n", + " o2_N2L_wp01[s1_1, s1_1] = 0.9\n", + "\n", + "assert np.isclose(o2_N2L_wp01.sum(axis=1), 1).all()" + ] + }, + { + "cell_type": "code", + "execution_count": 18, + "metadata": {}, + "outputs": [], + "source": [ + "ventilation_on_ = o2_L2N_wp07 @ vent_on\n", + "ventilation_off_ = o2_N2L_wp01 @ vent_off\n", + "ventilation_on = np.block([[ventilation_on_, np.zeros((nS, nS))], [np.zeros((nS, nS)), ventilation_on_]])\n", + "ventilation_off = np.block([[ventilation_off_, np.zeros((nS, nS))], [np.zeros((nS, nS)), ventilation_off_]])\n", + "assert np.isclose(ventilation_on.sum(axis=1), 1).all()\n", + "assert np.isclose(ventilation_off.sum(axis=1), 1).all()" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "### Vasopressor" + ] + }, + { + "cell_type": "code", + "execution_count": 19, + "metadata": {}, + "outputs": [], + "source": [ + "# vaso indicator\n", + "vaso_on = np.zeros((nS, nS))\n", + "vaso_off = np.zeros((nS, nS))\n", + "for (hr, sbp, o2, glu, abx, vent, _) in itertools.product(\n", + 
" *[state_variable_values[key] for key in state_variables if key not in ['vaso']]\n", + "):\n", + " s0 = State(state_categs=[hr, sbp, o2, glu, abx, 0, vent], diabetic_idx=0).get_state_idx()\n", + " s1 = State(state_categs=[hr, sbp, o2, glu, abx, 1, vent], diabetic_idx=0).get_state_idx()\n", + " vaso_on[s0, s1] = 1\n", + " vaso_on[s1, s1] = 1\n", + " vaso_off[s1, s0] = 1\n", + " vaso_off[s0, s0] = 1\n", + "\n", + "assert np.isclose(vaso_on.sum(axis=1), 1).all()\n", + "assert np.isclose(vaso_off.sum(axis=1), 1).all()" + ] + }, + { + "cell_type": "code", + "execution_count": 20, + "metadata": {}, + "outputs": [], + "source": [ + "# vaso affects sbp (non-diabetic)\n", + "# low->normal wp 0.7, normal->high wp 0.7\n", + "sbp_L2N_N2H_wp07 = np.zeros((nS, nS))\n", + "for (hr, o2, glu, abx, vent, vaso, _) in itertools.product(\n", + " *[state_variable_values[key] for key in state_variables if key not in ['sbp']]\n", + "):\n", + " s0 = State(state_categs=[hr, 0, o2, glu, abx, vaso, vent], diabetic_idx=0).get_state_idx()\n", + " s1 = State(state_categs=[hr, 1, o2, glu, abx, vaso, vent], diabetic_idx=0).get_state_idx()\n", + " s2 = State(state_categs=[hr, 2, o2, glu, abx, vaso, vent], diabetic_idx=0).get_state_idx()\n", + " sbp_L2N_N2H_wp07[s0, s0] = 0.3\n", + " sbp_L2N_N2H_wp07[s0, s1] = 0.7\n", + " sbp_L2N_N2H_wp07[s1, s1] = 0.3\n", + " sbp_L2N_N2H_wp07[s1, s2] = 0.7\n", + " sbp_L2N_N2H_wp07[s2, s2] = 1\n", + "\n", + "assert np.isclose(sbp_L2N_N2H_wp07.sum(axis=1), 1).all()" + ] + }, + { + "cell_type": "code", + "execution_count": 21, + "metadata": {}, + "outputs": [], + "source": [ + "# vaso affects sbp (diabetic)\n", + "# low->normal wp 0.5, low->high wp 0.4, normal->high wp 0.9\n", + "sbp_L2N2H = np.zeros((nS, nS))\n", + "for (hr, o2, glu, abx, vent, vaso, _) in itertools.product(\n", + " *[state_variable_values[key] for key in state_variables if key not in ['sbp']]\n", + "):\n", + " s0 = State(state_categs=[hr, 0, o2, glu, abx, vaso, vent], 
diabetic_idx=0).get_state_idx()\n", + " s1 = State(state_categs=[hr, 1, o2, glu, abx, vaso, vent], diabetic_idx=0).get_state_idx()\n", + " s2 = State(state_categs=[hr, 2, o2, glu, abx, vaso, vent], diabetic_idx=0).get_state_idx()\n", + " sbp_L2N2H[s0, s0] = 0.1\n", + " sbp_L2N2H[s0, s1] = 0.5\n", + " sbp_L2N2H[s0, s2] = 0.4\n", + " sbp_L2N2H[s1, s1] = 0.1\n", + " sbp_L2N2H[s1, s2] = 0.9\n", + " sbp_L2N2H[s2, s2] = 1\n", + "\n", + "assert np.isclose(sbp_L2N2H.sum(axis=1), 1).all()" + ] + }, + { + "cell_type": "code", + "execution_count": 22, + "metadata": {}, + "outputs": [], + "source": [ + "# vaso affects glu (diabetic)\n", + "# LL->L, L->N, N->H, H->HH wp 0.5\n", + "glu_raise_by_1 = np.zeros((nS, nS))\n", + "for (hr, sbp, o2, abx, vent, vaso, _) in itertools.product(\n", + " *[state_variable_values[key] for key in state_variables if key not in ['glu']]\n", + "):\n", + " s0 = State(state_categs=[hr, sbp, o2, 0, abx, vaso, vent], diabetic_idx=0).get_state_idx()\n", + " s1 = State(state_categs=[hr, sbp, o2, 1, abx, vaso, vent], diabetic_idx=0).get_state_idx()\n", + " s2 = State(state_categs=[hr, sbp, o2, 2, abx, vaso, vent], diabetic_idx=0).get_state_idx()\n", + " s3 = State(state_categs=[hr, sbp, o2, 3, abx, vaso, vent], diabetic_idx=0).get_state_idx()\n", + " s4 = State(state_categs=[hr, sbp, o2, 4, abx, vaso, vent], diabetic_idx=0).get_state_idx()\n", + " glu_raise_by_1[s0, s0] = 0.5\n", + " glu_raise_by_1[s0, s1] = 0.5\n", + " glu_raise_by_1[s1, s1] = 0.5\n", + " glu_raise_by_1[s1, s2] = 0.5\n", + " glu_raise_by_1[s2, s2] = 0.5\n", + " glu_raise_by_1[s2, s3] = 0.5\n", + " glu_raise_by_1[s3, s3] = 0.5\n", + " glu_raise_by_1[s3, s4] = 0.5\n", + " glu_raise_by_1[s4, s4] = 1\n", + "\n", + "assert np.isclose(glu_raise_by_1.sum(axis=1), 1).all()" + ] + }, + { + "cell_type": "code", + "execution_count": 23, + "metadata": {}, + "outputs": [], + "source": [ + "vasopressor_on = np.block([\n", + " [sbp_L2N_N2H_wp07 @ vaso_on, np.zeros((nS, nS))],\n", + " [np.zeros((nS, 
nS)), sbp_L2N2H @ glu_raise_by_1 @ vaso_on]\n", + "])\n", + "assert np.isclose(vasopressor_on.sum(axis=1), 1).all()" + ] + }, + { + "cell_type": "code", + "execution_count": 24, + "metadata": {}, + "outputs": [], + "source": [ + "# vaso withdrawn affects sbp (non-diabetic)\n", + "# N->L, H->N wp 0.1\n", + "sbp_H2N2L_wp01 = np.zeros((nS, nS))\n", + "for (hr, o2, glu, abx, vent, vaso, _) in itertools.product(\n", + " *[state_variable_values[key] for key in state_variables if key not in ['sbp']]\n", + "):\n", + " s0 = State(state_categs=[hr, 0, o2, glu, abx, vaso, vent], diabetic_idx=0).get_state_idx()\n", + " s1 = State(state_categs=[hr, 1, o2, glu, abx, vaso, vent], diabetic_idx=0).get_state_idx()\n", + " s2 = State(state_categs=[hr, 2, o2, glu, abx, vaso, vent], diabetic_idx=0).get_state_idx()\n", + " s0_0 = State(state_categs=[hr, 0, o2, glu, abx, 0, vent], diabetic_idx=0).get_state_idx()\n", + " s1_0 = State(state_categs=[hr, 1, o2, glu, abx, 0, vent], diabetic_idx=0).get_state_idx()\n", + " s2_0 = State(state_categs=[hr, 2, o2, glu, abx, 0, vent], diabetic_idx=0).get_state_idx()\n", + " s0_1 = State(state_categs=[hr, 0, o2, glu, abx, 1, vent], diabetic_idx=0).get_state_idx()\n", + " s1_1 = State(state_categs=[hr, 1, o2, glu, abx, 1, vent], diabetic_idx=0).get_state_idx()\n", + " s2_1 = State(state_categs=[hr, 2, o2, glu, abx, 1, vent], diabetic_idx=0).get_state_idx()\n", + " sbp_H2N2L_wp01[s0_0, s0_0] = 1\n", + " sbp_H2N2L_wp01[s0_1, s0_1] = 1\n", + " sbp_H2N2L_wp01[s1_0, s1_0] = 1\n", + " sbp_H2N2L_wp01[s1_1, s1_1] = 0.9\n", + " sbp_H2N2L_wp01[s1_1, s0_1] = 0.1\n", + " sbp_H2N2L_wp01[s2_0, s2_0] = 1\n", + " sbp_H2N2L_wp01[s2_1, s2_1] = 0.9\n", + " sbp_H2N2L_wp01[s2_1, s1_1] = 0.1\n", + "\n", + "assert np.isclose(sbp_H2N2L_wp01.sum(axis=1), 1).all()" + ] + }, + { + "cell_type": "code", + "execution_count": 25, + "metadata": {}, + "outputs": [], + "source": [ + "# vaso withdrawn affects sbp (diabetic)\n", + "# N->L, H->N wp 0.05\n", + "sbp_H2N2L_wp005 = 
np.zeros((nS, nS))\n", + "for (hr, o2, glu, abx, vent, vaso, _) in itertools.product(\n", + " *[state_variable_values[key] for key in state_variables if key not in ['sbp']]\n", + "):\n", + " s0 = State(state_categs=[hr, 0, o2, glu, abx, vaso, vent], diabetic_idx=0).get_state_idx()\n", + " s1 = State(state_categs=[hr, 1, o2, glu, abx, vaso, vent], diabetic_idx=0).get_state_idx()\n", + " s2 = State(state_categs=[hr, 2, o2, glu, abx, vaso, vent], diabetic_idx=0).get_state_idx()\n", + " s0_0 = State(state_categs=[hr, 0, o2, glu, abx, 0, vent], diabetic_idx=0).get_state_idx()\n", + " s1_0 = State(state_categs=[hr, 1, o2, glu, abx, 0, vent], diabetic_idx=0).get_state_idx()\n", + " s2_0 = State(state_categs=[hr, 2, o2, glu, abx, 0, vent], diabetic_idx=0).get_state_idx()\n", + " s0_1 = State(state_categs=[hr, 0, o2, glu, abx, 1, vent], diabetic_idx=0).get_state_idx()\n", + " s1_1 = State(state_categs=[hr, 1, o2, glu, abx, 1, vent], diabetic_idx=0).get_state_idx()\n", + " s2_1 = State(state_categs=[hr, 2, o2, glu, abx, 1, vent], diabetic_idx=0).get_state_idx()\n", + " sbp_H2N2L_wp005[s0_0, s0_0] = 1\n", + " sbp_H2N2L_wp005[s0_1, s0_1] = 1\n", + " sbp_H2N2L_wp005[s1_0, s1_0] = 1\n", + " sbp_H2N2L_wp005[s1_1, s1_1] = 0.95\n", + " sbp_H2N2L_wp005[s1_1, s0_1] = 0.05\n", + " sbp_H2N2L_wp005[s2_0, s2_0] = 1\n", + " sbp_H2N2L_wp005[s2_1, s2_1] = 0.95\n", + " sbp_H2N2L_wp005[s2_1, s1_1] = 0.05\n", + "\n", + "assert np.isclose(sbp_H2N2L_wp005.sum(axis=1), 1).all()" + ] + }, + { + "cell_type": "code", + "execution_count": 26, + "metadata": {}, + "outputs": [], + "source": [ + "vasopressor_off = np.block([\n", + " [sbp_H2N2L_wp01 @ vaso_off, np.zeros((nS, nS))],\n", + " [np.zeros((nS, nS)), sbp_H2N2L_wp005 @ vaso_off],\n", + "])\n", + "assert np.isclose(vasopressor_off.sum(axis=1), 1).all()" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "## Fluctuate\n", + "- all (non-treatment) states fluctuate +/- 1 w.p. .1\n", + "- exception: glucose flucuates +/- 1 w.p. 
.3 if diabetic" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "### hr" + ] + }, + { + "cell_type": "code", + "execution_count": 27, + "metadata": {}, + "outputs": [], + "source": [ + "# (abx != 1) && (s_abx != 1)\n", + "hr_fluctuate = np.zeros((nS, nS))\n", + "for (sbp, o2, glu, abx, vent, vaso, _) in itertools.product(\n", + " *[state_variable_values[key] for key in state_variables if key not in ['hr']]\n", + "):\n", + " s0 = State(state_categs=[0, sbp, o2, glu, abx, vaso, vent], diabetic_idx=0).get_state_idx()\n", + " s1 = State(state_categs=[1, sbp, o2, glu, abx, vaso, vent], diabetic_idx=0).get_state_idx()\n", + " s2 = State(state_categs=[2, sbp, o2, glu, abx, vaso, vent], diabetic_idx=0).get_state_idx()\n", + " s0_0 = State(state_categs=[0, sbp, o2, glu, 0, vaso, vent], diabetic_idx=0).get_state_idx()\n", + " s1_0 = State(state_categs=[1, sbp, o2, glu, 0, vaso, vent], diabetic_idx=0).get_state_idx()\n", + " s2_0 = State(state_categs=[2, sbp, o2, glu, 0, vaso, vent], diabetic_idx=0).get_state_idx()\n", + " s0_1 = State(state_categs=[0, sbp, o2, glu, 1, vaso, vent], diabetic_idx=0).get_state_idx()\n", + " s1_1 = State(state_categs=[1, sbp, o2, glu, 1, vaso, vent], diabetic_idx=0).get_state_idx()\n", + " s2_1 = State(state_categs=[2, sbp, o2, glu, 1, vaso, vent], diabetic_idx=0).get_state_idx()\n", + " hr_fluctuate[s0_0, s0_0] = 0.9\n", + " hr_fluctuate[s0_0, s1_0] = 0.1\n", + " hr_fluctuate[s1_0, s0_0] = 0.1\n", + " hr_fluctuate[s1_0, s1_0] = 0.8\n", + " hr_fluctuate[s1_0, s2_0] = 0.1\n", + " hr_fluctuate[s2_0, s1_0] = 0.1\n", + " hr_fluctuate[s2_0, s2_0] = 0.9\n", + " \n", + " hr_fluctuate[s0_1, s0_1] = 1\n", + " hr_fluctuate[s1_1, s1_1] = 1\n", + " hr_fluctuate[s2_1, s2_1] = 1\n", + "\n", + "assert np.isclose(hr_fluctuate.sum(axis=1), 1).all()\n", + "hr_fluctuate = np.block([[hr_fluctuate, np.zeros((nS, nS))], [np.zeros((nS, nS)), hr_fluctuate]])" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "### sbp" + 
] + }, + { + "cell_type": "code", + "execution_count": 28, + "metadata": {}, + "outputs": [], + "source": [ + "# vaso withdrawn affects sbp (non-diabetic)\n", + "# N->L, H->N wp 0.1\n", + "sbp_fluctuate = np.zeros((nS, nS))\n", + "for (hr, o2, glu, abx, vent, vaso, _) in itertools.product(\n", + " *[state_variable_values[key] for key in state_variables if key not in ['sbp']]\n", + "):\n", + " s0 = State(state_categs=[hr, 0, o2, glu, abx, vaso, vent], diabetic_idx=0).get_state_idx()\n", + " s1 = State(state_categs=[hr, 1, o2, glu, abx, vaso, vent], diabetic_idx=0).get_state_idx()\n", + " s2 = State(state_categs=[hr, 2, o2, glu, abx, vaso, vent], diabetic_idx=0).get_state_idx()\n", + " \n", + " s0_00 = State(state_categs=[hr, 0, o2, glu, 0, 0, vent], diabetic_idx=0).get_state_idx()\n", + " s1_00 = State(state_categs=[hr, 1, o2, glu, 0, 0, vent], diabetic_idx=0).get_state_idx()\n", + " s2_00 = State(state_categs=[hr, 2, o2, glu, 0, 0, vent], diabetic_idx=0).get_state_idx()\n", + " s0_01 = State(state_categs=[hr, 0, o2, glu, 0, 1, vent], diabetic_idx=0).get_state_idx()\n", + " s1_01 = State(state_categs=[hr, 1, o2, glu, 0, 1, vent], diabetic_idx=0).get_state_idx()\n", + " s2_01 = State(state_categs=[hr, 2, o2, glu, 0, 1, vent], diabetic_idx=0).get_state_idx()\n", + " \n", + " s0_10 = State(state_categs=[hr, 0, o2, glu, 1, 0, vent], diabetic_idx=0).get_state_idx()\n", + " s1_10 = State(state_categs=[hr, 1, o2, glu, 1, 0, vent], diabetic_idx=0).get_state_idx()\n", + " s2_10 = State(state_categs=[hr, 2, o2, glu, 1, 0, vent], diabetic_idx=0).get_state_idx()\n", + " s0_11 = State(state_categs=[hr, 0, o2, glu, 1, 1, vent], diabetic_idx=0).get_state_idx()\n", + " s1_11 = State(state_categs=[hr, 1, o2, glu, 1, 1, vent], diabetic_idx=0).get_state_idx()\n", + " s2_11 = State(state_categs=[hr, 2, o2, glu, 1, 1, vent], diabetic_idx=0).get_state_idx()\n", + " \n", + " sbp_fluctuate[s0_01, s0_01] = 1\n", + " sbp_fluctuate[s1_01, s1_01] = 1\n", + " sbp_fluctuate[s2_01, s2_01] = 1\n", 
+ " \n", + " sbp_fluctuate[s0_10, s0_10] = 1\n", + " sbp_fluctuate[s1_10, s1_10] = 1\n", + " sbp_fluctuate[s2_10, s2_10] = 1\n", + " \n", + " sbp_fluctuate[s0_11, s0_11] = 1\n", + " sbp_fluctuate[s1_11, s1_11] = 1\n", + " sbp_fluctuate[s2_11, s2_11] = 1\n", + " \n", + " sbp_fluctuate[s0_00, s0_00] = 0.9\n", + " sbp_fluctuate[s0_00, s1_00] = 0.1\n", + " sbp_fluctuate[s1_00, s0_00] = 0.1\n", + " sbp_fluctuate[s1_00, s1_00] = 0.8\n", + " sbp_fluctuate[s1_00, s2_00] = 0.1\n", + " sbp_fluctuate[s2_00, s1_00] = 0.1\n", + " sbp_fluctuate[s2_00, s2_00] = 0.9\n", + "\n", + "assert np.isclose(sbp_fluctuate.sum(axis=1), 1).all()\n", + "sbp_fluctuate = np.block([[sbp_fluctuate, np.zeros((nS, nS))], [np.zeros((nS, nS)), sbp_fluctuate]])" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "### o2" + ] + }, + { + "cell_type": "code", + "execution_count": 29, + "metadata": {}, + "outputs": [], + "source": [ + "# vent withdrawn affects o2: normal->low wp 0.1\n", + "o2_fluctuate = np.zeros((nS, nS))\n", + "for (hr, sbp, glu, abx, vent, vaso, _) in itertools.product(\n", + " *[state_variable_values[key] for key in state_variables if key not in ['o2']]\n", + "):\n", + " s0 = State(state_categs=[hr, sbp, 0, glu, abx, vaso, vent], diabetic_idx=0).get_state_idx()\n", + " s1 = State(state_categs=[hr, sbp, 1, glu, abx, vaso, vent], diabetic_idx=0).get_state_idx()\n", + " s0_0 = State(state_categs=[hr, sbp, 0, glu, abx, vaso, 0], diabetic_idx=0).get_state_idx()\n", + " s1_0 = State(state_categs=[hr, sbp, 1, glu, abx, vaso, 0], diabetic_idx=0).get_state_idx()\n", + " s0_1 = State(state_categs=[hr, sbp, 0, glu, abx, vaso, 1], diabetic_idx=0).get_state_idx()\n", + " s1_1 = State(state_categs=[hr, sbp, 1, glu, abx, vaso, 1], diabetic_idx=0).get_state_idx()\n", + " o2_fluctuate[s0_0, s0_0] = 0.9\n", + " o2_fluctuate[s0_0, s1_0] = 0.1\n", + " o2_fluctuate[s1_0, s0_0] = 0.1\n", + " o2_fluctuate[s1_0, s1_0] = 0.9\n", + " \n", + " o2_fluctuate[s0_1, s0_1] = 1\n", + " 
o2_fluctuate[s1_1, s1_1] = 1\n", + "\n", + "assert np.isclose(o2_fluctuate.sum(axis=1), 1).all()\n", + "o2_fluctuate = np.block([[o2_fluctuate, np.zeros((nS, nS))], [np.zeros((nS, nS)), o2_fluctuate]])" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "### glu" + ] + }, + { + "cell_type": "code", + "execution_count": 30, + "metadata": {}, + "outputs": [], + "source": [ + "# non-diabetic wp 0.1\n", + "glu_fluctuate_01 = np.zeros((nS, nS))\n", + "for (hr, sbp, o2, abx, vent, vaso, _) in itertools.product(\n", + " *[state_variable_values[key] for key in state_variables if key not in ['glu']]\n", + "):\n", + " s0 = State(state_categs=[hr, sbp, o2, 0, abx, vaso, vent], diabetic_idx=0).get_state_idx()\n", + " s1 = State(state_categs=[hr, sbp, o2, 1, abx, vaso, vent], diabetic_idx=0).get_state_idx()\n", + " s2 = State(state_categs=[hr, sbp, o2, 2, abx, vaso, vent], diabetic_idx=0).get_state_idx()\n", + " s3 = State(state_categs=[hr, sbp, o2, 3, abx, vaso, vent], diabetic_idx=0).get_state_idx()\n", + " s4 = State(state_categs=[hr, sbp, o2, 4, abx, vaso, vent], diabetic_idx=0).get_state_idx()\n", + " glu_fluctuate_01[s0, s0] = 0.9\n", + " glu_fluctuate_01[s0, s1] = 0.1\n", + " glu_fluctuate_01[s1, s0] = 0.1\n", + " glu_fluctuate_01[s1, s1] = 0.8\n", + " glu_fluctuate_01[s1, s2] = 0.1\n", + " glu_fluctuate_01[s2, s1] = 0.1\n", + " glu_fluctuate_01[s2, s2] = 0.8\n", + " glu_fluctuate_01[s2, s3] = 0.1\n", + " glu_fluctuate_01[s3, s2] = 0.1\n", + " glu_fluctuate_01[s3, s3] = 0.8\n", + " glu_fluctuate_01[s3, s4] = 0.1\n", + " glu_fluctuate_01[s4, s3] = 0.1\n", + " glu_fluctuate_01[s4, s4] = 0.9\n", + "\n", + "assert np.isclose(glu_fluctuate_01.sum(axis=1), 1).all()" + ] + }, + { + "cell_type": "code", + "execution_count": 31, + "metadata": {}, + "outputs": [], + "source": [ + "# diabetic wp 0.3\n", + "glu_fluctuate_03 = np.zeros((nS, nS))\n", + "for (hr, sbp, o2, abx, vent, vaso, _) in itertools.product(\n", + " *[state_variable_values[key] for key in 
state_variables if key not in ['glu']]\n", + "):\n", + " s0 = State(state_categs=[hr, sbp, o2, 0, abx, vaso, vent], diabetic_idx=0).get_state_idx()\n", + " s1 = State(state_categs=[hr, sbp, o2, 1, abx, vaso, vent], diabetic_idx=0).get_state_idx()\n", + " s2 = State(state_categs=[hr, sbp, o2, 2, abx, vaso, vent], diabetic_idx=0).get_state_idx()\n", + " s3 = State(state_categs=[hr, sbp, o2, 3, abx, vaso, vent], diabetic_idx=0).get_state_idx()\n", + " s4 = State(state_categs=[hr, sbp, o2, 4, abx, vaso, vent], diabetic_idx=0).get_state_idx()\n", + " glu_fluctuate_03[s0, s0] = 0.7\n", + " glu_fluctuate_03[s0, s1] = 0.3\n", + " glu_fluctuate_03[s1, s0] = 0.3\n", + " glu_fluctuate_03[s1, s1] = 0.4\n", + " glu_fluctuate_03[s1, s2] = 0.3\n", + " glu_fluctuate_03[s2, s1] = 0.3\n", + " glu_fluctuate_03[s2, s2] = 0.4\n", + " glu_fluctuate_03[s2, s3] = 0.3\n", + " glu_fluctuate_03[s3, s2] = 0.3\n", + " glu_fluctuate_03[s3, s3] = 0.4\n", + " glu_fluctuate_03[s3, s4] = 0.3\n", + " glu_fluctuate_03[s4, s3] = 0.3\n", + " glu_fluctuate_03[s4, s4] = 0.7\n", + "\n", + "assert np.isclose(glu_fluctuate_03.sum(axis=1), 1).all()" + ] + }, + { + "cell_type": "code", + "execution_count": 32, + "metadata": {}, + "outputs": [], + "source": [ + "glu_fluctuate = np.block([\n", + " [glu_fluctuate_01, np.zeros((nS, nS))],\n", + " [np.zeros((nS, nS)), glu_fluctuate_03],\n", + "])\n", + "assert np.isclose(glu_fluctuate.sum(axis=1), 1).all()" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "## Assemble Transition Matrix (A,S,S)" + ] + }, + { + "cell_type": "code", + "execution_count": 33, + "metadata": {}, + "outputs": [], + "source": [ + "# abx, vaso, vent\n", + "transition_000 = (antibiotics_off @ ventilation_off @ vasopressor_off).T @ (hr_fluctuate @ sbp_fluctuate @ o2_fluctuate @ glu_fluctuate).T\n", + "\n", + "transition_100 = (antibiotics_on @ ventilation_off @ vasopressor_off).T @ (o2_fluctuate @ glu_fluctuate).T\n", + "transition_010 = (antibiotics_off @ 
ventilation_on @ vasopressor_off).T @ (hr_fluctuate @ sbp_fluctuate @ glu_fluctuate).T\n", + "transition_001 = (antibiotics_off @ ventilation_off @ vasopressor_on).T @ (hr_fluctuate @ o2_fluctuate).T\n", + "\n", + "transition_110 = (antibiotics_on @ ventilation_on @ vasopressor_off).T @ (glu_fluctuate).T\n", + "transition_011 = (antibiotics_off @ ventilation_on @ vasopressor_on).T @ (hr_fluctuate).T\n", + "transition_101 = (antibiotics_on @ ventilation_off @ vasopressor_on).T @ (o2_fluctuate).T\n", + "\n", + "transition_111 = (antibiotics_on @ ventilation_on @ vasopressor_on).T" + ] + }, + { + "cell_type": "code", + "execution_count": 34, + "metadata": {}, + "outputs": [], + "source": [ + "transition_matrix = np.array([\n", + " transition_000.T,\n", + " transition_001.T,\n", + " transition_010.T,\n", + " transition_011.T,\n", + " transition_100.T,\n", + " transition_101.T,\n", + " transition_110.T,\n", + " transition_111.T,\n", + "])" + ] + }, + { + "cell_type": "code", + "execution_count": 35, + "metadata": {}, + "outputs": [], + "source": [ + "assert np.isclose(transition_matrix.sum(axis=2), 1).all()" + ] + }, + { + "cell_type": "code", + "execution_count": 36, + "metadata": {}, + "outputs": [ + { + "data": { + "text/plain": [ + "array([[[0.6561 , 0. , 0. , ..., 0. , 0. , 0. ],\n", + " [0.729 , 0. , 0. , ..., 0. , 0. , 0. ],\n", + " [0.729 , 0. , 0. , ..., 0. , 0. , 0. ],\n", + " ...,\n", + " [0. , 0. , 0. , ..., 0. , 0. , 0. ],\n", + " [0. , 0. , 0. , ..., 0. , 0. , 0. ],\n", + " [0. , 0. , 0. , ..., 0. , 0. , 0. ]],\n", + "\n", + " [[0. , 0. , 0.243 , ..., 0. , 0. , 0. ],\n", + " [0. , 0. , 0.27 , ..., 0. , 0. , 0. ],\n", + " [0. , 0. , 0.243 , ..., 0. , 0. , 0. ],\n", + " ...,\n", + " [0. , 0. , 0. , ..., 0. , 0. , 0. ],\n", + " [0. , 0. , 0. , ..., 0. , 0. , 0. ],\n", + " [0. , 0. , 0. , ..., 0. , 0. , 0. ]],\n", + "\n", + " [[0. , 0.2187 , 0. , ..., 0. , 0. , 0. ],\n", + " [0. , 0.2187 , 0. , ..., 0. , 0. , 0. ],\n", + " [0. , 0.243 , 0. , ..., 0. , 0. , 0. 
],\n", + " ...,\n", + " [0. , 0. , 0. , ..., 0. , 0. , 0. ],\n", + " [0. , 0. , 0. , ..., 0. , 0. , 0. ],\n", + " [0. , 0. , 0. , ..., 0. , 0. , 0. ]],\n", + "\n", + " ...,\n", + "\n", + " [[0. , 0. , 0. , ..., 0. , 0. , 0. ],\n", + " [0. , 0. , 0. , ..., 0. , 0. , 0. ],\n", + " [0. , 0. , 0. , ..., 0. , 0. , 0. ],\n", + " ...,\n", + " [0. , 0. , 0. , ..., 0. , 0.4275 , 0. ],\n", + " [0. , 0. , 0. , ..., 0. , 0.4275 , 0. ],\n", + " [0. , 0. , 0. , ..., 0. , 0.4275 , 0. ]],\n", + "\n", + " [[0. , 0. , 0. , ..., 0. , 0. , 0. ],\n", + " [0. , 0. , 0. , ..., 0. , 0. , 0. ],\n", + " [0. , 0. , 0. , ..., 0. , 0. , 0. ],\n", + " ...,\n", + " [0. , 0. , 0. , ..., 0.175 , 0. , 0. ],\n", + " [0. , 0. , 0. , ..., 0.16625, 0. , 0. ],\n", + " [0. , 0. , 0. , ..., 0.16625, 0. , 0. ]],\n", + "\n", + " [[0. , 0. , 0. , ..., 0. , 0. , 0. ],\n", + " [0. , 0. , 0. , ..., 0. , 0. , 0. ],\n", + " [0. , 0. , 0. , ..., 0. , 0. , 0. ],\n", + " ...,\n", + " [0. , 0. , 0. , ..., 0. , 0. , 0.475 ],\n", + " [0. , 0. , 0. , ..., 0. , 0. , 0.475 ],\n", + " [0. , 0. , 0. , ..., 0. , 0. 
, 0.475 ]]])" + ] + }, + "execution_count": 36, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "transition_matrix" + ] + }, + { + "cell_type": "code", + "execution_count": 37, + "metadata": {}, + "outputs": [ + { + "data": { + "text/plain": [ + "(8, 1440, 1440)" + ] + }, + "execution_count": 37, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "transition_matrix.shape" + ] + }, + { + "cell_type": "code", + "execution_count": 38, + "metadata": {}, + "outputs": [], + "source": [ + "transition_matrix_absorbing = np.zeros((nA, nS*2+2, nS*2+2))\n", + "transition_matrix_absorbing[:, :nS*2, :nS*2] = transition_matrix\n", + "transition_matrix_absorbing[:, -2, -2] = 1\n", + "transition_matrix_absorbing[:, -1, -1] = 1\n", + "for s in range(nS):\n", + " if reward_per_state[s] == -1:\n", + " transition_matrix_absorbing[:, s, :] = 0\n", + " transition_matrix_absorbing[:, s, -2] = 1\n", + " transition_matrix_absorbing[:, s+nS, :] = 0\n", + " transition_matrix_absorbing[:, s+nS, -2] = 1\n", + " elif reward_per_state[s] == 1:\n", + " transition_matrix_absorbing[:, s, :] = 0\n", + " transition_matrix_absorbing[:, s, -1] = 1\n", + " transition_matrix_absorbing[:, s+nS, :] = 0\n", + " transition_matrix_absorbing[:, s+nS, -1] = 1" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "## Initial State Distribution (S,)" + ] + }, + { + "cell_type": "code", + "execution_count": 39, + "metadata": {}, + "outputs": [], + "source": [ + "prior_initial_state = np.zeros(nS*2)" + ] + }, + { + "cell_type": "code", + "execution_count": 40, + "metadata": {}, + "outputs": [], + "source": [ + "diab_prior = [0.8, 0.2]\n", + "hr_prior = [0.25, 0.5, 0.25]\n", + "sbp_prior = [0.25, 0.5, 0.25]\n", + "o2_prior = [0.2, 0.8]\n", + "glu_prior = [\n", + " [0.05, 0.15, 0.6, 0.15, 0.05], # non-diabetic\n", + " [0.01, 0.05, 0.15, 0.6, 0.19], # diabetic\n", + "]\n", + "abx, vent, vaso = (0,0,0)" + ] + }, + { + "cell_type": "code", + 
"execution_count": 41, + "metadata": {}, + "outputs": [], + "source": [ + "for (hr, sbp, o2, glu, diab) in itertools.product(\n", + " *[state_variable_values[key] for key in state_variables if key not in ['abx', 'vaso', 'vent']]\n", + "):\n", + " s = State(state_categs=[hr, sbp, o2, glu, 0, 0, 0], diabetic_idx=diab).get_state_idx('full')\n", + " prior_initial_state[s] = \\\n", + " diab_prior[diab] * hr_prior[hr] * sbp_prior[sbp] * o2_prior[o2] * glu_prior[diab][glu]" + ] + }, + { + "cell_type": "code", + "execution_count": 42, + "metadata": {}, + "outputs": [ + { + "data": { + "text/plain": [ + "(1440,)" + ] + }, + "execution_count": 42, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "prior_initial_state.shape" + ] + }, + { + "cell_type": "code", + "execution_count": 43, + "metadata": {}, + "outputs": [ + { + "data": { + "text/plain": [ + "array([0.0005, 0. , 0. , ..., 0. , 0. , 0. ])" + ] + }, + "execution_count": 43, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "prior_initial_state" + ] + }, + { + "cell_type": "code", + "execution_count": 44, + "metadata": {}, + "outputs": [], + "source": [ + "prior_initial_state_absorbing = np.zeros(nS*2+2)\n", + "prior_initial_state_absorbing[:nS*2] = prior_initial_state\n", + "prior_initial_state_absorbing[[*(reward_per_state != 0), *(reward_per_state != 0), True, True]] = 0 # do not start in an almost-terminal state\n", + "prior_initial_state_absorbing = prior_initial_state_absorbing / prior_initial_state_absorbing.sum() # renormalize" + ] + }, + { + "cell_type": "code", + "execution_count": 45, + "metadata": {}, + "outputs": [ + { + "data": { + "text/plain": [ + "(1442,)" + ] + }, + "execution_count": 45, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "prior_initial_state_absorbing.shape" + ] + }, + { + "cell_type": "code", + "execution_count": 46, + "metadata": {}, + "outputs": [ + { + "data": { + "text/plain": [ + "74" + ] + }, + 
"execution_count": 46, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "(prior_initial_state_absorbing > 0).sum()" + ] + }, + { + "cell_type": "code", + "execution_count": 47, + "metadata": {}, + "outputs": [ + { + "data": { + "text/plain": [ + "(array([376]),)" + ] + }, + "execution_count": 47, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "np.where(reward_per_state == 1)" + ] + }, + { + "cell_type": "code", + "execution_count": 48, + "metadata": {}, + "outputs": [ + { + "data": { + "text/plain": [ + "array([1, 1, 1, 2, 0, 0, 0])" + ] + }, + "execution_count": 48, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "State(1096, idx_type='full').get_state_vector()" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "## Modified Initial State Distribution (S,)" + ] + }, + { + "cell_type": "code", + "execution_count": 49, + "metadata": {}, + "outputs": [], + "source": [ + "prior_initial_state = np.zeros(nS*2)" + ] + }, + { + "cell_type": "code", + "execution_count": 50, + "metadata": {}, + "outputs": [], + "source": [ + "diab_prior = [0.8, 0.2]\n", + "hr_prior = [0.25, 0.5, 0.25]\n", + "sbp_prior = [0.25, 0.5, 0.25]\n", + "o2_prior = [0.2, 0.8]\n", + "glu_prior = [\n", + " [0.05, 0.15, 0.6, 0.15, 0.05], # non-diabetic\n", + " [0.01, 0.05, 0.15, 0.6, 0.19], # diabetic\n", + "]\n", + "abx_prior, vent_prior, vaso_prior = [0.5, 0.5], [0.5, 0.5], [0.5, 0.5]" + ] + }, + { + "cell_type": "code", + "execution_count": 51, + "metadata": {}, + "outputs": [], + "source": [ + "for (hr, sbp, o2, glu, abx, vaso, vent, diab) in itertools.product(\n", + " *[state_variable_values[key] for key in state_variables]\n", + "):\n", + " s = State(state_categs=[hr, sbp, o2, glu, abx, vaso, vent], diabetic_idx=diab).get_state_idx('full')\n", + " prior_initial_state[s] = \\\n", + " diab_prior[diab] * hr_prior[hr] * sbp_prior[sbp] * o2_prior[o2] * glu_prior[diab][glu] * abx_prior[abx] * 
vent_prior[vent] * vaso_prior[vaso]" + ] + }, + { + "cell_type": "code", + "execution_count": 52, + "metadata": {}, + "outputs": [ + { + "data": { + "text/plain": [ + "(1440,)" + ] + }, + "execution_count": 52, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "prior_initial_state.shape" + ] + }, + { + "cell_type": "code", + "execution_count": 53, + "metadata": {}, + "outputs": [ + { + "data": { + "text/plain": [ + "array([6.250e-05, 6.250e-05, 6.250e-05, ..., 2.375e-04, 2.375e-04,\n", + " 2.375e-04])" + ] + }, + "execution_count": 53, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "prior_initial_state" + ] + }, + { + "cell_type": "code", + "execution_count": 54, + "metadata": {}, + "outputs": [], + "source": [ + "prior_initial_state_absorbing = np.zeros(nS*2+2)\n", + "prior_initial_state_absorbing[:nS*2] = prior_initial_state\n", + "prior_initial_state_absorbing[[*(reward_per_state != 0), *(reward_per_state != 0), True, True]] = 0 # do not start in an almost-terminal state\n", + "prior_initial_state_absorbing = prior_initial_state_absorbing / prior_initial_state_absorbing.sum() # renormalize" + ] + }, + { + "cell_type": "code", + "execution_count": 55, + "metadata": {}, + "outputs": [ + { + "data": { + "text/plain": [ + "(1442,)" + ] + }, + "execution_count": 55, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "prior_initial_state_absorbing.shape" + ] + }, + { + "cell_type": "code", + "execution_count": 56, + "metadata": {}, + "outputs": [ + { + "data": { + "text/plain": [ + "606" + ] + }, + "execution_count": 56, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "(prior_initial_state_absorbing > 0).sum()" + ] + }, + { + "cell_type": "code", + "execution_count": 57, + "metadata": {}, + "outputs": [ + { + "data": { + "text/plain": [ + "(array([376]),)" + ] + }, + "execution_count": 57, + "metadata": {}, + "output_type": "execute_result" + } + ], + 
"source": [ + "np.where(reward_per_state == 1)" + ] + }, + { + "cell_type": "code", + "execution_count": 58, + "metadata": {}, + "outputs": [ + { + "data": { + "text/plain": [ + "['../data/modified_prior_initial_state_absorbing.joblib']" + ] + }, + "execution_count": 58, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "joblib.dump(prior_initial_state_absorbing, '../data/modified_prior_initial_state_absorbing.joblib')" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "# Save" + ] + }, + { + "cell_type": "code", + "execution_count": 59, + "metadata": {}, + "outputs": [], + "source": [ + "MDP_parameters = {\n", + " 'transition_matrix': transition_matrix,\n", + " 'transition_matrix_absorbing': transition_matrix_absorbing,\n", + " 'reward_per_state': reward_per_state,\n", + " 'reward_matrix_ASS': reward_matrix_ASS,\n", + " 'reward_matrix_absorbing_SA': reward_matrix_absorbing_SA,\n", + " 'reward_matrix_absorbing_ASS': reward_matrix_absorbing_ASS,\n", + " 'prior_initial_state': prior_initial_state,\n", + " 'prior_initial_state_absorbing': prior_initial_state_absorbing, \n", + "}" + ] + }, + { + "cell_type": "code", + "execution_count": 60, + "metadata": {}, + "outputs": [ + { + "data": { + "text/plain": [ + "['../data/MDP_parameters.joblib']" + ] + }, + "execution_count": 60, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "joblib.dump(MDP_parameters, '../data/MDP_parameters.joblib')" + ] + }, + { + "cell_type": "code", + "execution_count": 61, + "metadata": {}, + "outputs": [ + { + "data": { + "text/plain": [ + "['../data/prior_initial_state_absorbing.joblib']" + ] + }, + "execution_count": 61, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "joblib.dump(prior_initial_state_absorbing, '../data/prior_initial_state_absorbing.joblib')" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [] + } + ], + 
"metadata": { + "kernelspec": { + "display_name": "RL_venv", + "language": "python", + "name": "rl_venv" + }, + "language_info": { + "codemirror_mode": { + "name": "ipython", + "version": 3 + }, + "file_extension": ".py", + "mimetype": "text/x-python", + "name": "python", + "nbconvert_exporter": "python", + "pygments_lexer": "ipython3", + "version": "3.8.8" + } + }, + "nbformat": 4, + "nbformat_minor": 4 +} diff --git a/sepsisSim/data-prep/1-ground-truth-policy.ipynb b/sepsisSim/data-prep/1-ground-truth-policy.ipynb new file mode 100644 index 0000000..bce29b2 --- /dev/null +++ b/sepsisSim/data-prep/1-ground-truth-policy.ipynb @@ -0,0 +1,3303 @@ +{ + "cells": [ + { + "cell_type": "code", + "execution_count": 1, + "metadata": {}, + "outputs": [], + "source": [ + "%load_ext autoreload\n", + "%autoreload 2" + ] + }, + { + "cell_type": "code", + "execution_count": 3, + "metadata": {}, + "outputs": [], + "source": [ + "import matplotlib\n", + "matplotlib.rcParams['font.sans-serif'] = ['Arial']\n", + "%config InlineBackend.figure_formats = ['svg']" + ] + }, + { + "cell_type": "code", + "execution_count": 4, + "metadata": {}, + "outputs": [], + "source": [ + "import numpy as np\n", + "import pandas as pd\n", + "import matplotlib.pyplot as plt\n", + "import seaborn as sns\n", + "import joblib\n", + "import pickle\n", + "import copy\n", + "import itertools\n", + "from tqdm import tqdm\n", + "from joblib import Parallel, delayed\n", + "import mdptoolbox.mdp as mdptools" + ] + }, + { + "cell_type": "code", + "execution_count": 5, + "metadata": {}, + "outputs": [], + "source": [ + "# Ground truth MDP model\n", + "MDP_parameters = joblib.load('../data/MDP_parameters.joblib')\n", + "P = MDP_parameters['transition_matrix_absorbing'] # (A, S, S_next)\n", + "R = MDP_parameters['reward_matrix_absorbing_SA'] # (S, A)\n", + "nS, nA = R.shape\n", + "gamma = 0.99" + ] + }, + { + "cell_type": "code", + "execution_count": 6, + "metadata": {}, + "outputs": [], + "source": [ + "# unif rand 
isd, mixture of diabetic state\n", + "PROB_DIAB = 0.2\n", + "isd = joblib.load('../data/modified_prior_initial_state_absorbing.joblib')\n", + "isd = (isd > 0).astype(float)\n", + "isd[:720] = isd[:720] / isd[:720].sum() * (1-PROB_DIAB)\n", + "isd[720:] = isd[720:] / isd[720:].sum() * (PROB_DIAB)" + ] + }, + { + "cell_type": "code", + "execution_count": 7, + "metadata": {}, + "outputs": [ + { + "data": { + "text/plain": [ + "(1442, 8)" + ] + }, + "execution_count": 7, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "nS, nA" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "## Helper functions" + ] + }, + { + "cell_type": "code", + "execution_count": 8, + "metadata": {}, + "outputs": [], + "source": [ + "# Map back to the policy in discrete MDP\n", + "def convert_to_policy_table(pi):\n", + " pol = np.zeros((nS, nA))\n", + " pol[list(np.arange(nS-2)), pi[:-2]] = 1\n", + " pol[-2:, 0] = 1\n", + " return pol" + ] + }, + { + "cell_type": "code", + "execution_count": 9, + "metadata": {}, + "outputs": [], + "source": [ + "def policy_eval_analytic(P, R, pi, gamma):\n", + " \"\"\"\n", + " Given the MDP model (transition probability P (S,A,S) and reward function R (S,A)),\n", + " Compute the value function of a policy using matrix inversion\n", + " \n", + " V_π = (I - γ P_π)^-1 R_π\n", + " \"\"\"\n", + " nS, nA = R.shape\n", + " R_pi = np.sum(R * pi, axis=1)\n", + " P_pi = np.sum(P * np.expand_dims(pi, 2), axis=1)\n", + " V_pi = np.linalg.inv(np.eye(nS) - gamma * P_pi) @ R_pi\n", + " return V_pi" + ] + }, + { + "cell_type": "code", + "execution_count": 10, + "metadata": {}, + "outputs": [], + "source": [ + "# For plotting\n", + "def conv_to_np(this_list):\n", + " this_arr = np.array(this_list)[:, np.newaxis]\n", + " # Make this idempotent\n", + " this_arr = this_arr.squeeze()[:, np.newaxis]\n", + " return this_arr\n", + "\n", + "def simulator_eval_on_policy(π):\n", + " NSTEPS = 100\n", + " PROB_DIAB = 0.2\n", + " DISCOUNT 
= 0.99\n", + " USE_BOOSTRAP=True\n", + " N_BOOTSTRAP = 100\n", + " \n", + " # Get the true RL reward as a sanity check\n", + " # Note that the RL policy includes actions for \"death\" and \"discharge\" absorbing states, which we ignore by taking [:-2, :]\n", + " from sepsisSimDiabetes.DataGenerator import DataGenerator\n", + " import cf.counterfactual as cf\n", + " import cf.utils as utils\n", + " \n", + " np.random.seed(90000)\n", + " dgen = DataGenerator()\n", + " NSIMSAMPS_RL = 1000\n", + " states_full_rl, actions_full_rl, lengths_full_rl, rewards_full_rl, diab_full_rl, _, _ = dgen.simulate(\n", + " NSIMSAMPS_RL, NSTEPS, policy=π[:-2, :], policy_idx_type='full', \n", + " p_diabetes=PROB_DIAB, modified=True, use_tqdm=False) #True, tqdm_desc='RL Policy Simulation')\n", + "\n", + " obs_samps_full_rlpol = utils.format_dgen_samps(\n", + " states_full_rl, actions_full_rl, rewards_full_rl, diab_full_rl, NSTEPS, NSIMSAMPS_RL)\n", + "\n", + " this_true_full_rl_reward = cf.eval_on_policy(\n", + " obs_samps_full_rlpol, discount=DISCOUNT, \n", + " bootstrap=USE_BOOSTRAP, n_bootstrap=N_BOOTSTRAP) # Need a second axis to concat later\n", + " \n", + " return conv_to_np([this_true_full_rl_reward])" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "## Planning" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "### Policy Iteration" + ] + }, + { + "cell_type": "code", + "execution_count": 11, + "metadata": {}, + "outputs": [], + "source": [ + "PI = mdptools.PolicyIteration(P, R, discount=gamma)\n", + "PI.run()\n", + "V_star_PI = np.array(PI.V)\n", + "π_star_PI = np.array(PI.policy)" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "### Value Iteration" + ] + }, + { + "cell_type": "code", + "execution_count": 12, + "metadata": {}, + "outputs": [], + "source": [ + "VI = mdptools.ValueIteration(P, R, discount=gamma, epsilon=1e-10)\n", + "VI.run()\n", + "V_star_VI = np.array(VI.V)\n", + "π_star_VI = 
np.array(VI.policy)\n", + "\n", + "# re-evalute the learned policy\n", + "pi_star_VI = convert_to_policy_table(π_star_VI)\n", + "V_π_star_PE = policy_eval_analytic(P.transpose((1,0,2)), R, pi_star_VI, gamma)" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "### Plot and save" + ] + }, + { + "cell_type": "code", + "execution_count": 13, + "metadata": {}, + "outputs": [ + { + "data": { + "image/svg+xml": [ + "\n", + "\n", + "\n", + "\n", + " \n", + " \n", + " \n", + " \n", + " 2021-04-05T11:34:10.026755\n", + " image/svg+xml\n", + " \n", + " \n", + " Matplotlib v3.3.4, https://matplotlib.org/\n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " 
\n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " 
\n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " 
\n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " 
\n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + "\n" + ], + "text/plain": [ + "
" + ] + }, + "metadata": { + "needs_background": "light" + }, + "output_type": "display_data" + } + ], + "source": [ + "plt.hist(V_star_VI, bins=50, alpha=0.3, label='VI values')\n", + "plt.hist(V_π_star_PE, bins=50, alpha=0.3, label='PE values')\n", + "plt.hist(V_star_PI, bins=50, alpha=0.3, label='PI values')\n", + "plt.legend()\n", + "plt.show()" + ] + }, + { + "cell_type": "code", + "execution_count": 14, + "metadata": {}, + "outputs": [ + { + "data": { + "text/plain": [ + "0.6775418119368786" + ] + }, + "execution_count": 14, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "isd @ V_star_PI" + ] + }, + { + "cell_type": "code", + "execution_count": 15, + "metadata": {}, + "outputs": [ + { + "data": { + "text/plain": [ + "True" + ] + }, + "execution_count": 15, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "np.all(π_star_VI == π_star_PI)" + ] + }, + { + "cell_type": "code", + "execution_count": 16, + "metadata": {}, + "outputs": [ + { + "data": { + "text/plain": [ + "['../data/π_star.joblib']" + ] + }, + "execution_count": 16, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "joblib.dump(π_star_PI, '../data/π_star_PI.joblib')\n", + "joblib.dump(V_star_PI, '../data/V_star_PI.joblib')\n", + "\n", + "joblib.dump(π_star_VI, '../data/π_star_VI.joblib')\n", + "joblib.dump(V_star_VI, '../data/V_star_VI.joblib')\n", + "joblib.dump(V_π_star_PE, '../data/V_π_star_PE.joblib')\n", + "\n", + "joblib.dump(pi_star_VI, '../data/π_star.joblib')" + ] + }, + { + "cell_type": "code", + "execution_count": 17, + "metadata": {}, + "outputs": [], + "source": [ + "true_rewards_list = [simulator_eval_on_policy(π) for π in [pi_star_VI]]" + ] + }, + { + "cell_type": "code", + "execution_count": 18, + "metadata": {}, + "outputs": [ + { + "data": { + "image/svg+xml": [ + "\n", + "\n", + "\n", + "\n", + " \n", + " \n", + " \n", + " \n", + " 2021-04-05T11:34:14.594709\n", + " image/svg+xml\n", + " \n", + " 
\n", + " Matplotlib v3.3.4, https://matplotlib.org/\n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", 
+ " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + "\n" + ], + "text/plain": [ + "
" + ] + }, + "metadata": { + "needs_background": "light" + }, + "output_type": "display_data" + }, + { + "name": "stdout", + "output_type": "stream", + "text": [ + "0.6641307574316647\n" + ] + } + ], + "source": [ + "reward_df = pd.DataFrame(np.concatenate(true_rewards_list, axis=1))\n", + "sns.boxplot(data=reward_df, whis=[2.5, 97.5], width=0.5, linewidth=1)\n", + "plt.ylabel(\"Average Reward\")\n", + "plt.xlabel('iteration')\n", + "plt.show()\n", + "print(reward_df.median().item())" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [] + } + ], + "metadata": { + "kernelspec": { + "display_name": "RL_venv", + "language": "python", + "name": "rl_venv" + }, + "language_info": { + "codemirror_mode": { + "name": "ipython", + "version": 3 + }, + "file_extension": ".py", + "mimetype": "text/x-python", + "name": "python", + "nbconvert_exporter": "python", + "pygments_lexer": "ipython3", + "version": "3.8.8" + } + }, + "nbformat": 4, + "nbformat_minor": 4 +} diff --git a/sepsisSim/data-prep/cf/__init__.py b/sepsisSim/data-prep/cf/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/sepsisSim/data-prep/cf/counterfactual.py b/sepsisSim/data-prep/cf/counterfactual.py new file mode 100644 index 0000000..94e831f --- /dev/null +++ b/sepsisSim/data-prep/cf/counterfactual.py @@ -0,0 +1,579 @@ +""" +counterfactual +""" +import numpy as np +import mdptoolbox.mdp as mdptools +import warnings +import cf.gumbelTools as gt +from tqdm import tqdm_notebook as tqdm + +class MatrixMDP(object): + def __init__(self, tx_mat, r_mat, p_initial_state=None, p_mixture=None): + """__init__ + + :param tx_mat: Transition matrix of shape (n_components x n_actions x + n_states x n_states) or (n_actions x n_states x n_states) + :param r_mat: Reward matrix of shape (n_components x n_actions x + n_states x n_states) or (n_actions x n_states x n_states) + :param p_initial_state: Probability over initial states + :param p_mixture: 
Probability over "mixture" components, in this case + diabetes status + """ + # QA the size of the inputs + assert tx_mat.ndim == 4 or tx_mat.ndim == 3, \ + "Transition matrix wrong dims ({} != 3 or 4)".format(tx_mat.ndim) + assert r_mat.ndim == 4 or r_mat.ndim == 3, \ + "Reward matrix wrong dims ({} != 3 or 4)".format(tx_mat.ndim) + assert r_mat.shape == tx_mat.shape, \ + "Transition / Reward matricies not the same shape!" + assert tx_mat.shape[-1] == tx_mat.shape[-2], \ + "Last two dims of Tx matrix should be equal to num of states" + + # Get the number of actions and states + n_actions = tx_mat.shape[-3] + n_states = tx_mat.shape[-2] + + # Get the number of components in the mixture: + # If no hidden component, add a dummy so the rest of the interface works + if tx_mat.ndim == 3: + n_components = 1 + tx_mat = tx_mat[np.newaxis, ...] + r_mat = r_mat[np.newaxis, ...] + else: + n_components = tx_mat.shape[0] + + # Get the prior over initial states + if p_initial_state is not None: + if p_initial_state.ndim == 1: + p_initial_state = p_initial_state[np.newaxis, :] + + assert p_initial_state.shape == (n_components, n_states), \ + ("Prior over initial state is wrong shape " + "{} != (C x S)").format(p_initial_state.shape) + + # Get the prior over components + if n_components == 1: + p_mixture = np.array([1.0]) + elif p_mixture is not None: + assert p_mixture.shape == (n_components, ), \ + ("Prior over components is wrong shape " + "{} != (C)").format(p_mixture.shape) + + self.n_components = n_components + self.n_actions = n_actions + self.n_states = n_states + self.tx_mat = tx_mat + self.r_mat = r_mat + self.p_initial_state = p_initial_state + self.p_mixture = p_mixture + + self.current_state = None + self.component = None + + def reset(self): + """reset + + Reset the environment, and return the initial position + + :returns: Tuple of (initial state, component) + """ + # Draw from the mixture + if self.p_mixture is None: + self.component = 
np.random.randint(self.n_components) + else: + self.component = np.random.choice( + self.n_components, size=1, p=self.p_mixture.tolist())[0] + + # Draw an initial state + if self.p_initial_state is None: + self.current_state = np.random.randint(self.n_states) + else: + self.current_state = np.random.choice( + self.n_states, size=1, + p=self.p_initial_state[self.component, :].squeeze().tolist())[0] + + return self.current_state, self.component + + def step(self, action): + """step + + Take a step with the given action + + :action: Integer of the action + :returns: Tuple of (next_state, reward) + """ + assert action in range(self.n_actions), "Invalid action!" + is_term = False + + next_prob = self.tx_mat[ + self.component, action, self.current_state, + :].squeeze() + + assert np.isclose(next_prob.sum(), 1), "Probs do not sum to 1!" + + next_state = np.random.choice(self.n_states, size=1, p=next_prob)[0] + + reward = self.r_mat[self.component, action, + self.current_state, next_state] + self.current_state = next_state + + # In this MDP, rewards are only received at the terminal state + if reward != 0: + is_term = True + + return self.current_state, reward, is_term + + def policyIteration(self, discount=0.9, obs_pol=None, skip_check=False, + eval_type=1, return_raw=False): + """Calculate the optimal policy for the marginal tx_mat and r_mat, + using policy iteration from pymdptoolbox + + Note that this function marginalizes over any mixture components if + they exist. 
+ + :discount: Discount factor for rewards + :returns: Policy matrix with deterministic policy + + """ + # Define the marginalized transition and reward matrix + r_mat_obs = self.r_mat.T.dot(self.p_mixture).T + tx_mat_obs = self.tx_mat.T.dot(self.p_mixture).T + + # Run Policy Iteration + pi = mdptools.PolicyIteration( + tx_mat_obs, r_mat_obs, discount=discount, skip_check=skip_check, + policy0=obs_pol, eval_type=eval_type) + pi.setSilent() + pi.run() + + # Convert this (deterministic) policy pi into a matrix format + pol_opt = np.zeros((self.n_states, self.n_actions)) + pol_opt[np.arange(len(pi.policy)), pi.policy] = 1 + + if return_raw: + return pol_opt, pi + return pol_opt + + def policyEval(self, pi, discount=0.9, theta=1e-12): + # Define the marginalized transition and reward matrix + r_mat_obs = self.r_mat.T.dot(self.p_mixture).T + tx_mat_obs = self.tx_mat.T.dot(self.p_mixture).T + + nS = self.tx_mat.shape[-2] + nA = self.tx_mat.shape[-3] + P = tx_mat_obs.transpose((1,0,2)) + R = np.zeros((nS, nA)) + gamma = discount + for s in range(nS): + for a in range(nA): + R[s,a] = tx_mat_obs[a, s, :] @ r_mat_obs[a, s, :] + + R_pi = np.sum(R * pi, axis=1) + P_pi = np.sum(P * np.expand_dims(pi, 2), axis=1) + V_pi = np.linalg.inv(np.eye(nS) - gamma * P_pi) @ R_pi + return V_pi + +def policy_eval_analytic(env, pi, gamma): + """ + Given the MDP model (transition probability P and reward function R), + Compute the value function of a policy using matrix inversion + + V_π = (I - γ P_π)^-1 R_π + """ + nS, nA = env.nS, env.nA + P = env.p_transition + R = env.p_reward + R_pi = np.sum(R * pi, axis=1) + P_pi = np.sum(P * np.expand_dims(pi, 2), axis=1) + V_pi = np.linalg.inv(np.eye(nS) - gamma * P_pi) @ R_pi + return V_pi + +class BatchSampler(object): + """BatchSampler + + Samples batches of episodes + """ + def __init__(self, mdp): + assert isinstance(mdp, MatrixMDP), "mdp argument must be a MatrixMDP" + self.mdp = mdp + + def on_policy_sample(self, policy=None, n_steps=10, 
n_samps=1, out='array', + use_tqdm=False, tqdm_desc=''): + """on_policy_sample. + + :param policy: Stochastic matrix of size (n_states x n_actions), default is random policy + :param n_steps: Maximum length of an episode + :param n_samps: Number of episodes in the batch + :param out: (Not implemented) type of output, must be 'array' for now + :param use_tqdm: Whether or not to display progress bars + :param tqdm_desc: Description for progress bars + :returns: Array containing samples collected under the policy + """ + if policy is not None: + assert policy.shape == (self.mdp.n_states, self.mdp.n_actions), \ + "Policy is the wrong shape. {} != S x A".format(policy.shape) + + # For each trajectory, for each step, we record + # t, A_{t}, O_{t}, O_{t+1}, h_{t}, h_{t+1}, R_{t} + # Note that in the toy example, "h" corresponds to the hidden component + + assert out == 'array', "Only 'array' supported as output type for now" + result = np.zeros((n_samps, n_steps, 7)) + result[:, :, 1:4] = -1 # Placeholder for tracking the end of the seq + + for samp_idx in tqdm(range(n_samps), + disable=not(use_tqdm), desc=tqdm_desc): + current_state, component = self.mdp.reset() + + # Sample the trajectory + for time_idx in range(n_steps): + if policy is None: # Random Policy + this_action = np.random.randint(self.mdp.n_actions) + else: + this_action = np.random.choice( + self.mdp.n_actions, size=1, + p=policy[current_state, :].squeeze().tolist())[0] + + # Terminal state if the reward is nonzero + next_state, this_reward, is_term = self.mdp.step(this_action) + + # Record State + result[samp_idx, time_idx] = ( + time_idx, + this_action, + current_state, + next_state, + component, + component, + this_reward) + + current_state = next_state + if is_term: + break + + return result + + def cf_trajectory(self, batch, cf_policy, n_cf_samps=1, + use_tqdm=False, tqdm_desc=''): + """cf_trajectory + + :param batch: Output of the sampler, shape is (n_samps, n_steps, 7) + :param cf_policy: 
Counterfactual policy to evaluate + :param n_cf_samps: Counterfactual samples to draw per episode + :param use_tqdm: Whether or not to display progress bars + :param tqdm_desc: Description for progress bars + + :returns: Array containing counterfactual trajectories + """ + + # Used for Monte Carlo sampling + n_draws = 1000 + + # For each trajectory, for each step, we record + # t, A_{t}, O_{t}, O_{t+1}, h_{t}, h_{t+1}, R_{t} + # Note that in the toy example, "h" corresponds to the hidden component + n_obs_eps = batch.shape[0] + n_obs_steps = batch.shape[1] + + # Result matrix has an extra dimension for number of CF draws per OBS + result = np.zeros((n_obs_eps, n_cf_samps, n_obs_steps, 7)) + result[:, :, :, 0] = np.arange(n_obs_steps) + result[:, :, :, 1:4] = -1 # Placeholders for end of sequence + + # Take posterior over the mixture components in batch form + # NOTE: This code does not serve a purpose in our current toy example, + # because we define the MDP with a single component, but it could be + # used in a future experiment with a single time-independent confounder + if self.mdp.n_components == 1: + mx_posterior = np.ones((n_obs_eps, 1)) + else: + mx_posterior = self.mixture_posterior(batch) + + for obs_samp_idx in tqdm(range(n_obs_eps), disable=not(use_tqdm), desc=tqdm_desc): + for cf_samp_idx in range(n_cf_samps): + obs_actions = batch[obs_samp_idx, :, 1].astype(int).squeeze().tolist() + obs_from_states = batch[obs_samp_idx, :, 2].astype(int).squeeze().tolist() + obs_to_states = batch[obs_samp_idx, :, 3].astype(int).squeeze().tolist() + + # Same initial state + current_state = obs_from_states[0] + + # Infer / Sample from the mixture posterior + this_mx_posterior = mx_posterior[obs_samp_idx].tolist() + component = np.random.choice( + self.mdp.n_components, size=1, p=this_mx_posterior) + + for time_idx in range(n_obs_steps): + obs_action = obs_actions[time_idx] + + if cf_policy is None: # Random Policy + cf_action = np.random.randint(self.mdp.n_actions) + 
else: + cf_action = np.random.choice( + self.mdp.n_actions, size=1, + p=cf_policy[current_state, :].squeeze().tolist())[0] + + # Interventional probabilities under new action + new_interv_probs = \ + self.mdp.tx_mat[component, + cf_action, current_state, + :].squeeze().tolist() + + # If observed sequence did not terminate, then infer cf + # probabilities; Otherwise treat this as an interventional + # query (once we're past the final time-step of the + # observed sequence, there is no posterior over latents) + + if obs_action == -1: + cf_probs = new_interv_probs + else: + # Old and new interventional probabilities + prev_interv_probs = \ + self.mdp.tx_mat[component, + obs_action, obs_from_states[time_idx], + :].squeeze().tolist() + + assert prev_interv_probs[obs_to_states[time_idx]] != 0 + + # Infer counterfactual probabilities + cf_probs = tx_posterior( + prev_interv_probs, new_interv_probs, + obs=obs_to_states[time_idx], + n_samp=n_draws).tolist() + + next_state = np.random.choice( + self.mdp.n_states, size=1, p=cf_probs)[0] + this_reward = self.mdp.r_mat[ + component, cf_action, current_state, next_state] + + # Record result + result[obs_samp_idx, cf_samp_idx, time_idx] = ( + time_idx, + cf_action, + current_state, + next_state, + component, + component, + this_reward) + + if this_reward != 0 and time_idx != n_obs_steps - 1: + # Fill in next state, convention in obs_samps + result[obs_samp_idx, cf_samp_idx, time_idx + 1] = ( + time_idx + 1, + -1, + next_state, + -1, + component, + component, + 0) + break + + current_state = next_state + + return result + + def mixture_posterior(self, batch): + """mixture_posterior + Infer the posterior over the mixture components of the MDP + + :param batch: Batch of observed trajectories (n_samps x n_steps x 7) + + :returns: Posterior over mixture components (n_samps x n_components) + """ + n_samps = batch.shape[0] + n_steps = batch.shape[1] + posterior = np.zeros((n_samps, self.mdp.n_components)) + + # Ignore errors due to 
zeros + with np.errstate(divide='ignore'): + log_p_initial_state = np.log(self.mdp.p_initial_state) + log_p_mixture = np.log(self.mdp.p_mixture) + log_mat = np.log(self.mdp.tx_mat) + + for obs_samp_idx in range(n_samps): + + # Prior + this_log_posterior = log_p_mixture.copy() + + # Recall that batch is of size (n_samps x n_steps x 7) with cols: + # t, A_{t}, O_{t}, O_{t+1}, h_{t}, h_{t+1}, R_{t} + + # Update with likelihood of initial state + this_log_posterior += log_p_initial_state[ + :, batch[obs_samp_idx, 0, 2].astype(int)] + + for time_idx in range(n_steps): + # Stop when we reach the end of the sequence + if batch[obs_samp_idx, time_idx, 1] == -1: + break + # Update likelihood for observed transitions + this_log_posterior += log_mat[ + :, # Across components + batch[obs_samp_idx, time_idx, 1].astype(int), # Action taken + batch[obs_samp_idx, time_idx, 2].astype(int), # From this state + batch[obs_samp_idx, time_idx, 3].astype(int) # To this state + ] + + # Convert to normalized probabilities + this_posterior = np.exp(this_log_posterior) + try: + this_posterior = this_posterior / this_posterior.sum(axis=0) + except RuntimeWarning: + import pdb + pdb.set_trace() + + posterior[obs_samp_idx] = this_posterior + + return posterior + +def tx_posterior(p_c, p_t, obs=0, n_samp=1000): + """tx_posterior + + Get a posterior over counterfactual transitions + + :param p_c: "Control" probabilities, under observed action + :param p_t: "Treatment" probabilities, under different action + :param obs: Observed outcome under observed action + :param n_samp: Number of Monte Carlo samples from posterior + """ + assert isinstance(p_c, list), "Pass probabilities in as a list!" + assert isinstance(p_t, list), "Pass probabilities in as a list!" 
+ + n_cat = len(p_c) + assert len(p_c) == len(p_t) + assert obs in range(n_cat), "Obs is {}, not valid!".format(obs) + np.testing.assert_approx_equal(np.sum(p_c), 1) + np.testing.assert_approx_equal(np.sum(p_t), 1) + + # Define our categorical logits + with np.errstate(divide='ignore'): + logits_control = np.log(np.array(p_c)) + logits_treat = np.log(np.array(p_t)) + + assert p_c[obs] != 0, "Probability of observed event was zero!" + + # Note: These are the Gumbel values (just g), not log p + g + posterior_samp = gt.topdown(logits_control, obs, n_samp) + + # The posterior under control should give us the same result as the obs + assert ((posterior_samp + logits_control).argmax(axis=1) == obs).sum() == n_samp + + # Counterfactual distribution + # This throws a RunTimeWarning because logits_treat includes some -inf, but + # that is expected + posterior_sum = posterior_samp + logits_treat + + # Because some logits are -inf, some entries of posterior_sum will be nan, + # but this is OK - these correspond to zero-probability transitions. 
We + # just assert here that at least one of the entries for each sample is an + # actual number (read the assert below as: Make sure that none of the + # samples have all NaNs) + assert not np.any(np.all(np.isnan(posterior_sum), axis=1)) + posterior_treat = posterior_sum.argmax(axis=1) + + # Reshape posterior argmax into a 1-D one-hot encoding for each sample + mask = np.zeros(posterior_sum.shape) + mask[np.arange(len(posterior_sum)), posterior_treat] = 1 + posterior_prob = mask.sum(axis=0) / mask.shape[0] + + return posterior_prob + +def calc_reward(obs_samps, discount=0.9): + # Column 0 is a time index, column 6 is the reward + discounted_reward = (discount**obs_samps[..., 0] * obs_samps[..., 6]) + return discounted_reward.sum(axis=-1) # Take the last axis + +def eval_on_policy(obs_samps, discount=0.9, bootstrap=False, n_bootstrap=None): + """eval_on_policy + + :param obs_samps: + :param discount: + :param bootstrap: + :param n_bootstrap: + """ + obs_rewards = calc_reward(obs_samps, discount).squeeze() # 1D array + assert obs_rewards.ndim == 1 + + if bootstrap: + assert n_bootstrap is not None, "Please specify n_bootstrap" + bs_rewards = np.random.choice( + obs_rewards, + size=(n_bootstrap, obs_rewards.shape[0]), + replace=True) + return bs_rewards.mean(axis=1) + else: + return obs_rewards.mean() + +def eval_wis(obs_samps, obs_policy, new_policy, + discount=0.9, bootstrap=False, n_bootstrap=None): + """eval_off_policy + + Weighted Importance Sampling for Off Policy Evaluation + + :obs_samps: Observed samples + :policy: Stochastic policy to evaluate + :returns: Expected returns (scalar) + """ + # Check dimensions + assert obs_policy.ndim == 2 + assert new_policy.ndim == 2 + assert obs_samps.ndim == 3 + + # Precompute the discounted rewards and importance weights + obs_rewards = calc_reward(obs_samps, discount).squeeze() # 1D array + assert obs_rewards.ndim == 1 + + obs_actions = obs_samps[..., 1].astype(int) + obs_states = obs_samps[..., 2].astype(int) + + # 
NOTE: This fails silently in that action = -1 corresponds to the end + # of a sequence, but in this indexing will just take the last action. + # This is corrected in the next code block "deal with variable length..." + p_obs = obs_policy[obs_states, obs_actions] + p_new = new_policy[obs_states, obs_actions] + + # Deal with variable length sequences by setting ratio to 1 + terminated_idx = obs_actions == -1 + p_obs[terminated_idx] = 1 + p_new[terminated_idx] = 1 + + if not np.all(p_obs > 0): + import pdb + pdb.set_trace() + assert np.all(p_obs > 0), "Some actions had zero prob under p_obs, WIS fails" + + cum_ir = (p_new / p_obs).prod(axis=1) + + wis_idx = (cum_ir > 0) + + if wis_idx.sum() == 0: + print("Found zero matching WIS samples, continuing") + return np.nan, wis_idx, wis_idx.sum() + + if bootstrap: + assert n_bootstrap is not None, "Please specify n_bootstrap" + # Get indices, because we need to sample from cum_ir and rewards + idx = np.random.choice( + np.arange(obs_rewards.shape[0]), + size=(n_bootstrap, obs_rewards.shape[0]), + replace=True) + + # Keepdims so that we can broadcast + wis_bs_samps = (cum_ir[idx] / + cum_ir[idx].mean(axis=1, keepdims=True)) * obs_rewards[idx] + + if np.any(np.isnan(wis_bs_samps)): + import pdb + pdb.set_trace() + + # Return WIS, one per row + wis_est = wis_bs_samps.mean(axis=1) + return wis_est, wis_idx[idx], wis_idx[idx].sum() + else: + wis = (cum_ir / cum_ir.mean()) * obs_rewards + if np.any(np.isnan(wis)): + import pdb + pdb.set_trace() + + wis_est = wis.mean() + return wis_est, wis_idx, wis_idx.sum() diff --git a/sepsisSim/data-prep/cf/gumbelTools.py b/sepsisSim/data-prep/cf/gumbelTools.py new file mode 100644 index 0000000..8c1b817 --- /dev/null +++ b/sepsisSim/data-prep/cf/gumbelTools.py @@ -0,0 +1,53 @@ +''' +Tools for sampling efficiently from a Gumbel posterior + +Original code taken from https://cmaddis.github.io/gumbel-machinery, and then +modified to work as numpy arrays, and to fit our nomenclature, e.g. 
+* np.log(alpha) is replaced by log probabilities (which we refer to as logits) +* np.log(sum(alphas)) is removed, because it should always equal zero +''' +import numpy as np + +def truncated_gumbel(logit, truncation): + """truncated_gumbel + + :param logit: Location of the Gumbel variable (e.g., log probability) + :param truncation: Value of Maximum Gumbel + """ + # Note: In our code, -inf shows up for zero-probability events, which is + # handled in the topdown function + assert not np.isneginf(logit) + + gumbel = np.random.gumbel(size=(truncation.shape[0])) + logit + trunc_g = -np.log(np.exp(-gumbel) + np.exp(-truncation)) + return trunc_g + +def topdown(logits, k, nsamp=1): + """topdown + + Top-down sampling from the Gumbel posterior + + :param logits: log probabilities of each outcome + :param k: Index of observed maximum + :param nsamp: Number of samples from gumbel posterior + """ + np.testing.assert_approx_equal(np.sum(np.exp(logits)), 1), "Probabilities do not sum to 1" + ncat = logits.shape[0] + + gumbels = np.zeros((nsamp, ncat)) + + # Sample top gumbels + topgumbel = np.random.gumbel(size=(nsamp)) + + for i in range(ncat): + # This is the observed outcome + if i == k: + gumbels[:, k] = topgumbel - logits[i] + # These were the other feasible options (p > 0) + elif not(np.isneginf(logits[i])): + gumbels[:, i] = truncated_gumbel(logits[i], topgumbel) - logits[i] + # These have zero probability to start with, so are unconstrained + else: + gumbels[:, i] = np.random.gumbel(size=nsamp) + + return gumbels diff --git a/sepsisSim/data-prep/cf/utils.py b/sepsisSim/data-prep/cf/utils.py new file mode 100644 index 0000000..35b162b --- /dev/null +++ b/sepsisSim/data-prep/cf/utils.py @@ -0,0 +1,259 @@ +import numpy as np +import pandas as pd +from sepsisSimDiabetes.State import State + +import matplotlib.pyplot as plt +import seaborn as sns +import warnings +from matplotlib.ticker import FormatStrFormatter +warnings.simplefilter(action='ignore', 
category=FutureWarning) + +def format_dgen_samps(states, actions, rewards, hidden, NSTEPS, NSIMSAMPS): + """format_dgen_samps + Formats the output of the data generator (a batch of trajectories) in a way + that the other functions will consume + + :param states: states + :param actions: actions + :param rewards: rewards + :param hidden: hidden states + :param NSTEPS: Maximum length of trajectory + :param NSIMSAMPS: Number of trajectories + """ + obs_samps = np.zeros((NSIMSAMPS, NSTEPS, 7)) + obs_samps[:, :, 0] = np.arange(NSTEPS) # Time Index + obs_samps[:, :, 1] = actions[:, :, 0] + obs_samps[:, :, 2] = states[:, :-1, 0] # from_states + obs_samps[:, :, 3] = states[:, 1:, 0] # to_states + obs_samps[:, :, 4] = hidden[:, :, 0] # Hidden variable + obs_samps[:, :, 5] = hidden[:, :, 0] # Hidden variable + obs_samps[:, :, 6] = rewards[:, :, 0] + + return obs_samps + +def df_from_samps(samps, pt_idx=0, get_outcome=False, is_proj=False, is_full=False): + """df_from_samps + + Creates a dataframe from samples, selecting a specific patient in a batch, + and formatting in a way that is consumed by our plotting code + + :param samps: Sample trajectories + :param pt_idx: Patient index + :param get_outcome: Boolean, whether or not to return the outcome + :param is_proj: Whether or not this has been projected already + """ + # Find the end of the trajectory, which is one past the time when reward occurs + endtime = samps.shape[1] - 1 # By default, this is the end of the sequence + for t in range(samps.shape[1]): + if samps[pt_idx, t, 1] == -1: # Action = -1 indicates end + endtime = t + break + + # Extract individual arrays + if is_proj: + # For projected samples, want one step back, b/c last state is abs + time = np.arange(endtime) + elif endtime == samps.shape[1] - 1: + time = samps[pt_idx, :, 0].astype(int) + else: + time = np.arange(endtime + 1) # +1 to get inclusive + + states = samps[pt_idx, time, 2] # Go though endtime inclusive + diab_idx = samps[pt_idx, 0, 4] # Scalar + + 
state_array_2d = np.zeros((time.shape[0], 8)) + for t in time: + state_array_2d[t, 0] = t + if is_full: + this_state = State(state_idx = states[t], idx_type='full') + else: + if is_proj: + if states[t] > 144: + break + this_state = State( + state_idx = states[t], + idx_type='proj_obs', + diabetic_idx=diab_idx) + else: + this_state = State( + state_idx = states[t], + idx_type='obs', + diabetic_idx=diab_idx) + state_array_2d[t, 1:] = this_state.get_state_vector() + + df = pd.DataFrame(state_array_2d, columns = [ + 'Time', + 'Heart Rate', + 'SysBP', + 'Percent O2', + 'Glucose', + 'Treat: AbX', + 'Treat: Vaso', + 'Treat: Vent' + ]) + + # Get the outcome + if get_outcome and not is_proj: + outcome = (endtime, samps[pt_idx, endtime - 1, 6]) + return df, outcome + # The diff with proj is that the last state is at endtime-1 + elif get_outcome and is_proj: + outcome = (endtime - 1, samps[pt_idx, endtime - 1, 6]) + return df, outcome + else: + return df + +def plot_trajectory(samps, pt_idx=0, cf=False, cf_samps=None, cf_proj=False, + max_plt_len=None, force_length=None): + """plot_trajectory + + :param samps: Observed trajectory (output of format_dgen_samps) + :param pt_idx: Patient Index + :param cf: If true, plot distribution of counterfactuals + :param cf_samps: If cf, then these are the cf samples + :param cf_proj: Are these projected samples + :param max_plt_len: Maximum length to plot + :param force_length: Force length to a certain length + """ + this_df, outcome = df_from_samps(samps, pt_idx, get_outcome=True) + + eps = 0.5 + param_dict = { + 'Heart Rate': { + 'ticks': ['Low', 'Normal', 'High'], + 'vals': [0, 1, 2], + 'nrange': [0.75, 1.25], + 'plt_outcome': True, + 'ylabel': 'HR' + }, + 'SysBP': { + 'ticks': ['Low', 'Normal', 'High'], + 'vals': [0, 1, 2], + 'nrange': [0.75, 1.25], + 'plt_outcome': True, + 'ylabel': 'SysBP' + }, + 'Percent O2': { + 'ticks': ['Low', 'Normal'], + 'vals': [0, 1], + 'nrange': [0.75, 1.25], + 'plt_outcome': True, + 'ylabel': 'Pct O2' + }, 
+ 'Glucose': { + 'ticks': ['V. Low', 'Low', 'Normal', 'High', 'V. High'], + 'vals': [0, 1, 2, 3, 4], + 'nrange': [1.75, 2.25], + 'plt_outcome': True, + 'ylabel': 'Glucose' + }, + 'Treat: AbX': { + 'ticks': ['Off', 'On'], + 'vals': [0, 1], + 'nrange': None, + 'plt_outcome': True, + 'ylabel': 'Tx: Abx' + }, + 'Treat: Vaso':{ + 'ticks': ['Off', 'On'], + 'vals': [0, 1], + 'nrange': None, + 'plt_outcome': True, + 'ylabel': 'Tx: Vaso' + }, + 'Treat: Vent': { + 'ticks': ['Off', 'On'], + 'vals': [0, 1], + 'nrange': None, + 'plt_outcome': True, + 'ylabel': 'Tx: Vent' + }, + } + + outcome_symbol = { + -1: { + 'marker': 'o', + 'color': 'r', + 'markersize': '10' + }, + 0: { + 'marker': 'o', + 'color': 'k', + 'markersize': '10' + }, + 1: { + 'marker': 'o', + 'color': 'g', + 'markersize': '10' + } + } + + fig, axes = plt.subplots(7, 1, sharex=True) + fig.set_size_inches(8, 10) + for i in range(7): + this_col = this_df.columns[i+1] + axes[i].plot(this_df['Time'], this_df[this_col], color='k') + + # Format the Y-axis according to the variable + axes[i].set_ylabel(param_dict[this_col]['ylabel']) + axes[i].set_yticks(param_dict[this_col]['vals']) + axes[i].set_yticklabels(param_dict[this_col]['ticks']) + axes[i].set_ylim(param_dict[this_col]['vals'][0] - eps, + param_dict[this_col]['vals'][-1]+ eps) + + # Plot the end of the sequence as red, green, black + if param_dict[this_col]['plt_outcome']: + obs_end_time = outcome[0] + end_event = outcome[1].astype(int) + axes[i].plot( + obs_end_time, + this_df[this_col][obs_end_time], + marker=outcome_symbol[end_event]['marker'], + color=outcome_symbol[end_event]['color'] + ) + + nrange = param_dict[this_col]['nrange'] + last_time = this_df.shape[0] + if force_length is not None: + last_time = force_length + if nrange is not None: + axes[i].hlines(nrange, xmin=0, xmax=last_time, + colors='r', + linestyles='dotted', label='Normal Range') + + axes[i].set_xlim(-0.25, last_time + 0.5) + # Format the X-axis as integers + 
axes[i].xaxis.set_ticks(np.arange(0, last_time + 1, 2)) + axes[i].xaxis.set_major_formatter(FormatStrFormatter('%d')) + + if cf: + if max_plt_len is None: + max_plt_len = this_df.shape[0] + 1 + assert cf_samps is not None + num_samps = cf_samps.shape[1] + for i in range(num_samps): + # import pdb + # pdb.set_trace() + this_df, outcome = \ + df_from_samps(cf_samps[:, i, :max_plt_len, :], + pt_idx, get_outcome=True, is_proj=cf_proj) + for i in range(7): + this_col = this_df.columns[i+1] + # No CF trajectory for glucose + if this_col == 'Glucose': + continue + axes[i].plot(this_df['Time'], this_df[this_col], alpha=0.1, color='b') + # Plot the end of the sequence as red, green, yellow + if param_dict[this_col]['plt_outcome']: + end_time = outcome[0] + end_event = outcome[1].astype(int) + axes[i].plot( + end_time, + this_df[this_col][end_time], + marker=outcome_symbol[end_event]['marker'], + color=outcome_symbol[end_event]['color'], + alpha=0.3 + ) + + return fig, axes + diff --git a/sepsisSim/data-prep/datagen-features-N=1e5-vaso_eps=0_1.ipynb b/sepsisSim/data-prep/datagen-features-N=1e5-vaso_eps=0_1.ipynb new file mode 100644 index 0000000..9a2b48e --- /dev/null +++ b/sepsisSim/data-prep/datagen-features-N=1e5-vaso_eps=0_1.ipynb @@ -0,0 +1,10334 @@ +{ + "cells": [ + { + "cell_type": "code", + "execution_count": 1, + "metadata": {}, + "outputs": [], + "source": [ + "%load_ext autoreload\n", + "%autoreload 2" + ] + }, + { + "cell_type": "code", + "execution_count": 3, + "metadata": {}, + "outputs": [], + "source": [ + "import matplotlib\n", + "matplotlib.rcParams['font.sans-serif'] = ['Arial']\n", + "%config InlineBackend.figure_formats = ['svg']" + ] + }, + { + "cell_type": "code", + "execution_count": 4, + "metadata": { + "ExecuteTime": { + "end_time": "2019-04-19T16:43:04.753547Z", + "start_time": "2019-04-19T16:43:04.068986Z" + } + }, + "outputs": [], + "source": [ + "import numpy as np\n", + "import cf.counterfactual as cf\n", + "import cf.utils as utils\n", + 
"import pandas as pd\n", + "import pickle\n", + "import itertools as it\n", + "from tqdm import tqdm\n", + "import scipy.sparse\n", + "\n", + "# Sepsis Simulator code\n", + "from sepsisSimDiabetes.State import State\n", + "from sepsisSimDiabetes.Action import Action\n", + "from sepsisSimDiabetes.DataGenerator import DataGenerator\n", + "import sepsisSimDiabetes.MDP as simulator \n", + "\n", + "import matplotlib.pyplot as plt\n", + "import seaborn as sns\n", + "import warnings\n", + "warnings.simplefilter(action='ignore', category=FutureWarning)" + ] + }, + { + "cell_type": "code", + "execution_count": 5, + "metadata": {}, + "outputs": [], + "source": [ + "import joblib" + ] + }, + { + "cell_type": "code", + "execution_count": 6, + "metadata": {}, + "outputs": [], + "source": [ + "output_dir = '../datagen/vaso_eps_0_1-100k/'\n", + "\n", + "import pathlib\n", + "pathlib.Path(output_dir).mkdir(parents=True, exist_ok=True)" + ] + }, + { + "cell_type": "code", + "execution_count": 7, + "metadata": { + "ExecuteTime": { + "end_time": "2019-04-19T16:43:04.787255Z", + "start_time": "2019-04-19T16:43:04.770642Z" + } + }, + "outputs": [], + "source": [ + "NSIMSAMPS = 100_000 # Samples to draw from the simulator\n", + "NSTEPS = 20 # Max length of each trajectory\n", + "# DISCOUNT_Pol = 0.99 # Used for computing optimal policies\n", + "# DISCOUNT = 1 # Used for computing actual reward\n", + "\n", + "PROB_DIAB = 0.2\n", + "\n", + "# # Option 1: Use bootstrapping w/replacement on the original NSIMSAMPS to estimate errors\n", + "# USE_BOOSTRAP=True\n", + "# N_BOOTSTRAP = 100\n", + "\n", + "# # Option 2: Use repeated sampling (i.e., NSIMSAMPS fresh simulations each time) to get error bars; \n", + "# # This is done in the appendix of the paper, but not in the main paper\n", + "# N_REPEAT_SAMPLING = 1\n", + "\n", + "# # These are properties of the simulator, do not change\n", + "n_actions = Action.NUM_ACTIONS_TOTAL\n", + "# n_components = 2" + ] + }, + { + "cell_type": "code", + 
"execution_count": 8, + "metadata": {}, + "outputs": [], + "source": [ + "optPol = joblib.load('../data/π_star.joblib')" + ] + }, + { + "cell_type": "code", + "execution_count": 12, + "metadata": {}, + "outputs": [], + "source": [ + "# vaso eps=0.1, mv abx optimal\n", + "behaviorPol = (np.tile(optPol.reshape((-1,2,2,2)).sum(axis=3, keepdims=True), (1,1,1,2)).reshape((-1, 8)) / 2)\n", + "behaviorPol[optPol == 1] = 0.9\n", + "behaviorPol[behaviorPol == 0.5] = 0.1" + ] + }, + { + "cell_type": "code", + "execution_count": 13, + "metadata": {}, + "outputs": [ + { + "data": { + "text/plain": [ + "array([[0.9, 0.1, 0. , ..., 0. , 0. , 0. ],\n", + " [0.9, 0.1, 0. , ..., 0. , 0. , 0. ],\n", + " [0.9, 0.1, 0. , ..., 0. , 0. , 0. ],\n", + " ...,\n", + " [0.9, 0.1, 0. , ..., 0. , 0. , 0. ],\n", + " [0.9, 0.1, 0. , ..., 0. , 0. , 0. ],\n", + " [0.9, 0.1, 0. , ..., 0. , 0. , 0. ]])" + ] + }, + "execution_count": 13, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "behaviorPol" + ] + }, + { + "cell_type": "code", + "execution_count": 14, + "metadata": { + "ExecuteTime": { + "end_time": "2019-04-19T17:12:19.355539Z", + "start_time": "2019-04-19T17:12:19.340377Z" + } + }, + "outputs": [], + "source": [ + "def conv_to_np(this_list):\n", + " this_arr = np.array(this_list)[:, np.newaxis]\n", + " # Make this idempotent\n", + " this_arr = this_arr.squeeze()[:, np.newaxis]\n", + " return this_arr" + ] + }, + { + "cell_type": "code", + "execution_count": 15, + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Iteration: 1\n" + ] + }, + { + "name": "stderr", + "output_type": "stream", + "text": [ + " \r" + ] + }, + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Iteration: 2\n" + ] + }, + { + "name": "stderr", + "output_type": "stream", + "text": [ + " \r" + ] + } + ], + "source": [ + "# Generate 2 batches of data with N episodes each\n", + "for it in [1, 2]:\n", + " print('Iteration:', it, 
flush=True)\n", + " np.random.seed(it)\n", + " dgen = DataGenerator()\n", + " states, actions, lengths, rewards, diab, emp_tx_totals, emp_r_totals = dgen.simulate(\n", + " NSIMSAMPS, NSTEPS, policy=behaviorPol[:-2], policy_idx_type='full', output_state_idx_type='full',\n", + " p_diabetes=PROB_DIAB, modified=True, use_tqdm=True) #True, tqdm_desc='Behaviour Policy Simulation')\n", + "\n", + " obs_samps = utils.format_dgen_samps(\n", + " states, actions, rewards, diab, NSTEPS, NSIMSAMPS)\n", + "\n", + " df_samp_list = []\n", + " infos = []\n", + " for i in range(NSIMSAMPS):\n", + " pt_id = i + it*NSIMSAMPS\n", + " df_i, (len_i, y_i) = utils.df_from_samps(obs_samps, i, get_outcome=True, is_full=True)\n", + " df_i['pt_id'] = pt_id\n", + " df_i['Diabetic'] = diab[i][:len_i+1]\n", + " df_i['State_idx'] = states[i][:len_i+1]\n", + " df_i['Obs_idx'] = df_i['State_idx'].apply(lambda s: State(state_idx=s, idx_type='full').get_state_idx('obs'))\n", + " df_i['Proj_idx'] = df_i['State_idx'].apply(lambda s: State(state_idx=s, idx_type='full').get_state_idx('proj_obs'))\n", + " df_i['Action'] = actions[i][:len_i+1]\n", + " df_i['Reward'] = 0\n", + " df_i.loc[len_i, 'Reward'] = y_i\n", + " df_i = df_i.set_index('pt_id').reset_index()\n", + " df_samp_list.append(df_i)\n", + " infos.append([pt_id, diab[i][0], y_i, len_i, len(df_i)])\n", + "\n", + " df_samps = pd.concat(df_samp_list).astype(int)\n", + " df_features = pd.get_dummies(df_samps, columns=['Diabetic', 'Heart Rate', 'SysBP', 'Percent O2', 'Glucose', 'Treat: AbX', 'Treat: Vaso', 'Treat: Vent'])\n", + " df_samps_info = pd.DataFrame(infos, columns=['pt_id', 'Diabetic', 'Outcome', 'Steps', 'Length']).astype(int)\n", + " \n", + " assert df_samps.shape[1] == 15\n", + " assert df_features.shape[1] == 28\n", + " assert df_samps_info.shape[1] == 5\n", + " df_samps.to_csv('{}/{}-samples.csv'.format(output_dir, it), index=False)\n", + " df_features.to_csv('{}/{}-features.csv'.format(output_dir, it), index=False)\n", + " 
df_samps_info.to_csv('{}/{}-info.csv'.format(output_dir, it), index=False)\n", + " joblib.dump([states, actions, lengths, rewards, diab, emp_tx_totals, emp_r_totals], '{}/{}-alldata.joblib'.format(output_dir, it))\n", + " joblib.dump(obs_samps, '{}/{}-obs_samps.joblib'.format(output_dir, it))" + ] + }, + { + "cell_type": "code", + "execution_count": 16, + "metadata": {}, + "outputs": [ + { + "data": { + "image/svg+xml": [ + "\n", + "\n", + "\n", + "\n", + " \n", + " \n", + " \n", + " \n", + " 2022-11-23T15:27:33.968062\n", + " image/svg+xml\n", + " \n", + " \n", + " Matplotlib v3.3.4, https://matplotlib.org/\n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + 
" \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " 
\n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " 
\n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " 
\n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " 
\n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " 
\n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " 
\n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " 
\n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " 
\n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " 
\n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " 
\n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " 
\n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " 
\n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " 
\n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " 
\n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " 
\n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " 
\n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " 
\n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " 
\n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " 
\n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " 
\n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + "\n" + ], + "text/plain": [ + "
" + ] + }, + "metadata": { + "needs_background": "light" + }, + "output_type": "display_data" + } + ], + "source": [ + "plt.bar(range(1440), df_samps.groupby('pt_id').first()['State_idx'].value_counts().reindex(range(1440)))\n", + "plt.xlabel('initial state index')\n", + "plt.ylabel('count')\n", + "plt.show()" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "# Features" + ] + }, + { + "cell_type": "code", + "execution_count": 13, + "metadata": {}, + "outputs": [], + "source": [ + "import numpy as np\n", + "import pandas as pd\n", + "from pandas import DataFrame\n", + "from tqdm import tqdm\n", + "from collections import defaultdict\n", + "import pickle\n", + "import itertools\n", + "import copy\n", + "import random\n", + "import matplotlib.pyplot as plt\n", + "import seaborn as sns\n", + "from tqdm import tqdm" + ] + }, + { + "cell_type": "code", + "execution_count": 14, + "metadata": {}, + "outputs": [], + "source": [ + "from sklearn.linear_model import LinearRegression\n", + "from sklearn.neural_network import MLPRegressor\n", + "from sklearn import metrics\n", + "import joblib\n", + "from joblib import Parallel, delayed" + ] + }, + { + "cell_type": "code", + "execution_count": 15, + "metadata": {}, + "outputs": [], + "source": [ + "import scipy.sparse" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "## 21-dimensional state features" + ] + }, + { + "cell_type": "code", + "execution_count": 17, + "metadata": {}, + "outputs": [], + "source": [ + "nS, nA = 1442, 8\n", + "d = 21" + ] + }, + { + "cell_type": "code", + "execution_count": 18, + "metadata": {}, + "outputs": [], + "source": [ + "def get_state_action_feature(x_s, a):\n", + " x_sa = np.zeros((nA, d))\n", + " x_sa[a, :] = x_s\n", + " return x_sa.flatten()" + ] + }, + { + "cell_type": "code", + "execution_count": 19, + "metadata": {}, + "outputs": [], + "source": [ + "def make_features_single_trajectory(df_i):\n", + " # Initial timestep\n", + " s_init = 
df_i.iloc[0, 7:28].values\n", + " x_s_init = np.array(s_init)\n", + " xa_s_init_all = np.array([get_state_action_feature(x_s_init, a_) for a_ in range(nA)])\n", + "\n", + " # Intermediate timestep\n", + " if len(df_i) > 1:\n", + " s = df_i.iloc[:-1, 7:28].values\n", + " a = df_i.iloc[:-1]['Action'].values\n", + " r = df_i.iloc[:-1]['Reward'].values\n", + " s_next = df_i.iloc[1:, 7:28].values\n", + "\n", + " n = len(s)\n", + " x_s = np.array(s)\n", + " xa_sa = np.array([get_state_action_feature(x_s[j, :], a[j]) for j in range(n)])\n", + "\n", + " x_s_next = np.array(s_next)\n", + " xa_s_next_all = np.vstack([\n", + " np.vstack([get_state_action_feature(x_s_next[j], a_) for a_ in range(nA)]) \n", + " for j in range(n)\n", + " ])\n", + " else:\n", + " x_s = np.array((0, d))\n", + " a = np.zeros((0), dtype=int)\n", + " xa_sa = np.array((0, d*nA))\n", + " r = np.zeros((0))\n", + " x_s_next = np.array((0, d))\n", + " xa_s_next_all = np.array((0, d*nA))\n", + "\n", + " # Final timestep\n", + " s_last = df_i.iloc[-1, 7:28].values\n", + " a_last = df_i.iloc[-1]['Action']\n", + " r_last = df_i.iloc[-1]['Reward']\n", + " if r_last == -1 or r_last == 1:\n", + " # Reached death/disch states\n", + " # every action leads to reward\n", + " x_s_last = np.array(s_last)\n", + " xa_s_last_all = np.array([get_state_action_feature(x_s_last, a_) for a_ in range(nA)])\n", + " r_last_all = np.array(nA * [r_last])\n", + "\n", + " xa_out = np.vstack([xa_sa, xa_s_last_all])\n", + " xa_next_out = np.vstack([xa_s_next_all, np.zeros((nA*nA, nA*d))])\n", + " r_out = np.concatenate([r, r_last_all])\n", + "\n", + " a_out = np.concatenate([a, (list(range(nA)))])\n", + " x_out = np.vstack([x_s, *(nA*[x_s_last])])\n", + " x_next_out = np.vstack([x_s_next, np.zeros((nA, d))])\n", + " else: \n", + " # terminated early due to max length, so no next state information\n", + " xa_out = xa_sa\n", + " xa_next_out = xa_s_next_all\n", + " r_out = r\n", + "\n", + " x_out = x_s\n", + " a_out = a\n", + " 
x_next_out = x_s_next\n", + " \n", + " return x_s_init, xa_s_init_all, x_out, a_out, xa_out, r_out, x_next_out, xa_next_out" + ] + }, + { + "cell_type": "code", + "execution_count": 20, + "metadata": {}, + "outputs": [ + { + "name": "stderr", + "output_type": "stream", + "text": [ + "100%|██████████| 100000/100000 [03:17<00:00, 505.85it/s]\n" + ] + }, + { + "name": "stdout", + "output_type": "stream", + "text": [ + "(800000, 168) (1467199, 168) (11737592, 168) (1467199,)\n", + "(100000, 21) (1467199, 21) (1467199,) (1467199,) (1467199, 21)\n" + ] + }, + { + "name": "stderr", + "output_type": "stream", + "text": [ + "100%|██████████| 100000/100000 [03:28<00:00, 480.27it/s]\n" + ] + }, + { + "name": "stdout", + "output_type": "stream", + "text": [ + "(800000, 168) (1471272, 168) (11770176, 168) (1471272,)\n", + "(100000, 21) (1471272, 21) (1471272,) (1471272,) (1471272, 21)\n" + ] + } + ], + "source": [ + "for it in [1,2]:\n", + " df_features = pd.read_csv('{}/{}-features.csv'.format(output_dir, it))\n", + " out = [make_features_single_trajectory(df_i) for i, df_i in tqdm(df_features.groupby('pt_id'))]\n", + " X_init, Xa_init, X, A, Xa, R, X_next, Xa_next = zip(*out)\n", + " X_init = np.vstack(X_init)\n", + " Xa_init = np.vstack(Xa_init)\n", + " X = np.vstack(X)\n", + " Xa = np.vstack(Xa)\n", + " A = np.concatenate(A)\n", + " R = np.concatenate(R)\n", + " X_next = np.vstack(X_next)\n", + " Xa_next = np.vstack(Xa_next)\n", + " print(Xa_init.shape, Xa.shape, Xa_next.shape, R.shape)\n", + " print(X_init.shape, X.shape, A.shape, R.shape, X_next.shape)\n", + "\n", + " # Store indices of beginning of each episode\n", + " lengths = [len(x_i) for x_i in list(zip(*out))[2]]\n", + " inds_init = np.cumsum([0] + lengths)\n", + "\n", + " joblib.dump({\n", + " 'X_init': X_init, 'X': X, 'A': A, 'R': R, 'X_next': X_next, \n", + " 'Xa_init': Xa_init, 'Xa': Xa, 'Xa_next': Xa_next,\n", + " 'lengths': lengths, 'inds_init': inds_init,\n", + " }, 
'{}/{}-21d-feature-matrices.joblib'.format(output_dir, it))\n", + "\n", + " joblib.dump({\n", + " 'X_init': scipy.sparse.csr_matrix(X_init), 'X': scipy.sparse.csr_matrix(X), 'A': A, 'R': R, 'X_next': scipy.sparse.csr_matrix(X_next), \n", + " 'Xa_init': scipy.sparse.csr_matrix(Xa_init), 'Xa': scipy.sparse.csr_matrix(Xa), 'Xa_next': scipy.sparse.csr_matrix(Xa_next),\n", + " 'lengths': lengths, 'inds_init': inds_init,\n", + " }, '{}/{}-21d-feature-matrices.sparse.joblib'.format(output_dir, it))" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [] + } + ], + "metadata": { + "kernelspec": { + "display_name": "RL_venv", + "language": "python", + "name": "rl_venv" + }, + "language_info": { + "codemirror_mode": { + "name": "ipython", + "version": 3 + }, + "file_extension": ".py", + "mimetype": "text/x-python", + "name": "python", + "nbconvert_exporter": "python", + "pygments_lexer": "ipython3", + "version": "3.8.8" + } + }, + "nbformat": 4, + "nbformat_minor": 4 +} diff --git a/sepsisSim/data-prep/sepsisSimDiabetes/Action.py b/sepsisSim/data-prep/sepsisSimDiabetes/Action.py new file mode 100644 index 0000000..65939ef --- /dev/null +++ b/sepsisSim/data-prep/sepsisSimDiabetes/Action.py @@ -0,0 +1,81 @@ +import numpy as np + +class Action(object): + + NUM_ACTIONS_TOTAL = 8 + ANTIBIOTIC_STRING = "antibiotic" + VENT_STRING = "ventilation" + VASO_STRING = "vasopressors" + ACTION_VEC_SIZE = 3 + + def __init__(self, selected_actions = None, action_idx = None): + assert (selected_actions is not None and action_idx is None) \ + or (selected_actions is None and action_idx is not None), \ + "must specify either set of action strings or action index" + if selected_actions is not None: + if Action.ANTIBIOTIC_STRING in selected_actions: + self.antibiotic = 1 + else: + self.antibiotic = 0 + if Action.VENT_STRING in selected_actions: + self.ventilation = 1 + else: + self.ventilation = 0 + if Action.VASO_STRING in 
selected_actions: + self.vasopressors = 1 + else: + self.vasopressors = 0 + else: + mod_idx = action_idx + term_base = Action.NUM_ACTIONS_TOTAL/2 + self.antibiotic = np.floor(mod_idx/term_base).astype(int) + mod_idx %= term_base + term_base /= 2 + self.ventilation = np.floor(mod_idx/term_base).astype(int) + mod_idx %= term_base + term_base /= 2 + self.vasopressors = np.floor(mod_idx/term_base).astype(int) + + def __eq__(self, other): + return isinstance(other, self.__class__) and \ + self.antibiotic == other.antibiotic and \ + self.ventilation == other.ventilation and \ + self.vasopressors == other.vasopressors + + def __ne__(self, other): + return not self.__eq__(other) + + def get_action_idx(self): + assert self.antibiotic in (0, 1) + assert self.ventilation in (0, 1) + assert self.vasopressors in (0, 1) + return 4*self.antibiotic + 2*self.ventilation + self.vasopressors + + def __hash__(self): + return self.get_action_idx() + + def get_selected_actions(self): + selected_actions = set() + if self.antibiotic == 1: + selected_actions.add(Action.ANTIBIOTIC_STRING) + if self.ventilation == 1: + selected_actions.add(Action.VENT_STRING) + if self.vasopressors == 1: + selected_actions.add(Action.VASO_STRING) + return selected_actions + + def get_abbrev_string(self): + ''' + AEV: antibiotics, ventilation, vasopressors + ''' + output_str = '' + if self.antibiotic == 1: + output_str += 'A' + if self.ventilation == 1: + output_str += 'E' + if self.vasopressors == 1: + output_str += 'V' + return output_str + + def get_action_vec(self): + return np.array([[self.antibiotic], [self.ventilation], [self.vasopressors]]) diff --git a/sepsisSim/data-prep/sepsisSimDiabetes/DataGenerator.py b/sepsisSim/data-prep/sepsisSimDiabetes/DataGenerator.py new file mode 100644 index 0000000..778f361 --- /dev/null +++ b/sepsisSim/data-prep/sepsisSimDiabetes/DataGenerator.py @@ -0,0 +1,95 @@ +import numpy as np, random +from .MDP import MDP +from .State import State +from .Action import Action +# 
from tqdm import tqdm_notebook as tqdm +# from tqdm.notebook import tqdm +from tqdm import tqdm + +''' +Simulates data generation from an MDP +''' +class DataGenerator(object): + + def select_actions(self, state, policy): + ''' + select action for state from policy + if unspecified, a random action is returned + ''' + if state not in policy: + return Action(action_idx = np.random.randint(8)) + return policy[state] + + def simulate(self, num_iters, max_num_steps, + policy=None, policy_idx_type='full', p_diabetes=0.2, modified=False, + output_state_idx_type='obs', use_tqdm=False, tqdm_desc=''): + ''' + policy is an array of probabilities + ''' + assert policy is not None, "Please specify a policy" + + # Set the default value of states / actions to negative -1, + # corresponding to None + iter_states = np.ones((num_iters, max_num_steps+1, 1), dtype=int)*(-1) + iter_actions = np.ones((num_iters, max_num_steps, 1), dtype=int)*(-1) + iter_rewards = np.zeros((num_iters, max_num_steps, 1)) + iter_lengths = np.zeros((num_iters, 1), dtype=int) + + # Record diabetes, the hidden mixture component + iter_component = np.zeros((num_iters, max_num_steps, 1), dtype=int) + mdp = MDP(init_state_idx=None, # Random initial state + policy_array=policy, policy_idx_type=policy_idx_type, + p_diabetes=p_diabetes) + + # Empirical transition / reward matrix + if output_state_idx_type == 'obs': + emp_tx_mat = np.zeros((Action.NUM_ACTIONS_TOTAL, + State.NUM_OBS_STATES, State.NUM_OBS_STATES)) + emp_r_mat = np.zeros((Action.NUM_ACTIONS_TOTAL, + State.NUM_OBS_STATES, State.NUM_OBS_STATES)) + elif output_state_idx_type == 'full': + emp_tx_mat = np.zeros((Action.NUM_ACTIONS_TOTAL, + State.NUM_FULL_STATES, State.NUM_FULL_STATES)) + emp_r_mat = np.zeros((Action.NUM_ACTIONS_TOTAL, + State.NUM_FULL_STATES, State.NUM_FULL_STATES)) + else: + raise NotImplementedError() + + for itr in tqdm(range(num_iters), disable=not(use_tqdm), desc=tqdm_desc, leave=False): + # MDP will generate the diabetes index as 
well + mdp.state = mdp.get_new_state(modified=modified) + this_diabetic_idx = mdp.state.diabetic_idx + iter_component[itr, :] = this_diabetic_idx # Never changes + iter_states[itr, 0, 0] = mdp.state.get_state_idx( + idx_type=output_state_idx_type) + for step in range(max_num_steps): + assert not mdp.state.check_absorbing_state() + step_action = mdp.select_actions() + + this_action_idx = step_action.get_action_idx().astype(int) + this_from_state_idx = mdp.state.get_state_idx( + idx_type=output_state_idx_type).astype(int) + + # Take the action, new state is property of the MDP + step_reward = mdp.transition(step_action) + this_to_state_idx = mdp.state.get_state_idx( + idx_type=output_state_idx_type).astype(int) + + iter_actions[itr, step, 0] = this_action_idx + iter_states[itr, step+1, 0] = this_to_state_idx + + # Record in transition matrix + emp_tx_mat[this_action_idx, + this_from_state_idx, this_to_state_idx] += 1 + emp_r_mat[this_action_idx, + this_from_state_idx, this_to_state_idx] += step_reward + + if step_reward != 0: + iter_rewards[itr, step, 0] = step_reward + iter_lengths[itr, 0] = step+1 + break + + if step == max_num_steps-1: + iter_lengths[itr, 0] = max_num_steps + + return iter_states, iter_actions, iter_lengths, iter_rewards, iter_component, emp_tx_mat, emp_r_mat diff --git a/sepsisSim/data-prep/sepsisSimDiabetes/MDP.py b/sepsisSim/data-prep/sepsisSimDiabetes/MDP.py new file mode 100644 index 0000000..88df322 --- /dev/null +++ b/sepsisSim/data-prep/sepsisSimDiabetes/MDP.py @@ -0,0 +1,337 @@ +import numpy as np +from .State import State +from .Action import Action + +''' +Includes blood glucose level proxy for diabetes: 0-3 + (lo2, lo1, normal, hi1, hi2); Any other than normal is "abnormal" +Initial distribution: + [.05, .15, .6, .15, .05] for non-diabetics and [.01, .05, .15, .6, .19] for diabetics + +Effect of vasopressors on if diabetic: + raise blood pressure: normal -> hi w.p. .9, lo -> normal w.p. .5, lo -> hi w.p. 
.4 + raise blood glucose by 1 w.p. .5 + +Effect of vasopressors off if diabetic: + blood pressure falls by 1 w.p. .05 instead of .1 + glucose does not fall - apply fluctuations below instead + +Fluctuation in blood glucose levels (IV/insulin therapy are not possible actions): + fluctuate w.p. .3 if diabetic + fluctuate w.p. .1 if non-diabetic +Ref: https://www.ncbi.nlm.nih.gov/pmc/articles/PMC4530321/ + +Additional fluctuation regardless of other changes +This order is applied: + antibiotics, ventilation, vasopressors, fluctuations +''' + +class MDP(object): + + def __init__(self, init_state_idx=None, init_state_idx_type='obs', + policy_array=None, policy_idx_type='obs', p_diabetes=0.2, modified=False): + ''' + initialize the simulator + ''' + assert p_diabetes >= 0 and p_diabetes <= 1, \ + "Invalid p_diabetes: {}".format(p_diabetes) + assert policy_idx_type in ['obs', 'full', 'proj_obs'] + + # Check the policy dimensions (states x actions) + if policy_array is not None: + assert policy_array.shape[1] == Action.NUM_ACTIONS_TOTAL + if policy_idx_type == 'obs': + assert policy_array.shape[0] == State.NUM_OBS_STATES + elif policy_idx_type == 'full': + assert policy_array.shape[0] == \ + State.NUM_HID_STATES * State.NUM_OBS_STATES + elif policy_idx_type == 'proj_obs': + assert policy_array.shape[0] == State.NUM_PROJ_OBS_STATES + + # p_diabetes is used to generate random state if init_state is None + self.p_diabetes = p_diabetes + self.state = None + + # Only need to use init_state_idx_type if you are providing a state_idx! + self.state = self.get_new_state(init_state_idx, init_state_idx_type, modified=False) + + self.policy_array = policy_array + self.policy_idx_type = policy_idx_type # Used for mapping the policy to actions + + def get_new_state(self, state_idx = None, idx_type = 'obs', diabetic_idx = None, modified = False): + ''' + use to start MDP over. A few options: + + Full specification: + 1. Provide state_idx with idx_type = 'obs' + diabetic_idx + 2. 
Provide state_idx with idx_type = 'full', diabetic_idx is ignored + 3. Provide state_idx with idx_type = 'proj_obs' + diabetic_idx* + + * This option will set glucose to a normal level + + Random specification + 4. State_idx, no diabetic_idx: Latter will be generated + 5. No state_idx, no diabetic_idx: Completely random + 6. No state_idx, diabetic_idx given: Random conditional on diabetes + ''' + assert idx_type in ['obs', 'full', 'proj_obs'] + option = None + if state_idx is not None: + if idx_type == 'obs' and diabetic_idx is not None: + option = 'spec_obs' + elif idx_type == 'obs' and diabetic_idx is None: + option = 'spec_obs_no_diab' + diabetic_idx = np.random.binomial(1, self.p_diabetes) + elif idx_type == 'full': + option = 'spec_full' + elif idx_type == 'proj_obs' and diabetic_idx is not None: + option = 'spec_proj_obs' + elif state_idx is None and diabetic_idx is None: + option = 'random' + elif state_idx is None and diabetic_idx is not None: + option = 'random_cond_diab' + + assert option is not None, "Invalid specification of new state" + + if option in ['random', 'random_cond_diab'] : + init_state = self.generate_random_state(diabetic_idx, modified) + # Do not start in death or discharge state + while init_state.check_absorbing_state(): + init_state = self.generate_random_state(diabetic_idx, modified) + else: + # Note that diabetic_idx will be ignored if idx_type = 'full' + init_state = State( + state_idx=state_idx, idx_type=idx_type, + diabetic_idx=diabetic_idx) + + return init_state + + def generate_random_state(self, diabetic_idx=None, modified=False): + import joblib + if not modified: + isd = joblib.load('../data/prior_initial_state_absorbing.joblib') + else: + isd = joblib.load('../data/modified_prior_initial_state_absorbing.joblib') + isd = (isd > 0) + isd = isd / isd.sum() + + # Note that we will condition on diabetic idx if provided + if diabetic_idx is None: + diabetic_idx = np.random.binomial(1, self.p_diabetes) + + if diabetic_idx == 0: + 
prior = isd[:720] + prior = prior / prior.sum() + s = np.random.choice(np.arange(720), p=prior) + else: + prior = isd[720:1440] + prior = prior / prior.sum() + s = 720 + np.random.choice(np.arange(720), p=prior) + + return State(state_idx=s, idx_type='full') + + def generate_random_state_OLD(self, diabetic_idx=None): + # Note that we will condition on diabetic idx if provided + if diabetic_idx is None: + diabetic_idx = np.random.binomial(1, self.p_diabetes) + + # hr and sys_bp w.p. [.25, .5, .25] + hr_state = np.random.choice(np.arange(3), p=np.array([.25, .5, .25])) + sysbp_state = np.random.choice(np.arange(3), p=np.array([.25, .5, .25])) + # percoxyg w.p. [.2, .8] + percoxyg_state = np.random.choice(np.arange(2), p=np.array([.2, .8])) + + if diabetic_idx == 0: + glucose_state = np.random.choice(np.arange(5), \ + p=np.array([.05, .15, .6, .15, .05])) + else: + glucose_state = np.random.choice(np.arange(5), \ + p=np.array([.01, .05, .15, .6, .19])) + antibiotic_state = 0 + vaso_state = 0 + vent_state = 0 + + state_categs = [hr_state, sysbp_state, percoxyg_state, + glucose_state, antibiotic_state, vaso_state, vent_state] + + return State(state_categs=state_categs, diabetic_idx=diabetic_idx) + + def transition_antibiotics_on(self): + ''' + antibiotics state on + heart rate, sys bp: hi -> normal w.p. .5 + ''' + self.state.antibiotic_state = 1 + if self.state.hr_state == 2 and np.random.uniform(0,1) < 0.5: + self.state.hr_state = 1 + if self.state.sysbp_state == 2 and np.random.uniform(0,1) < 0.5: + self.state.sysbp_state = 1 + + def transition_antibiotics_off(self): + ''' + antibiotics state off + if antibiotics was on: heart rate, sys bp: normal -> hi w.p. 
.1 + ''' + if self.state.antibiotic_state == 1: + if self.state.hr_state == 1 and np.random.uniform(0,1) < 0.1: + self.state.hr_state = 2 + if self.state.sysbp_state == 1 and np.random.uniform(0,1) < 0.1: + self.state.sysbp_state = 2 + self.state.antibiotic_state = 0 + + def transition_vent_on(self): + ''' + ventilation state on + percent oxygen: low -> normal w.p. .7 + ''' + self.state.vent_state = 1 + if self.state.percoxyg_state == 0 and np.random.uniform(0,1) < 0.7: + self.state.percoxyg_state = 1 + + def transition_vent_off(self): + ''' + ventilation state off + if ventilation was on: percent oxygen: normal -> lo w.p. .1 + ''' + if self.state.vent_state == 1: + if self.state.percoxyg_state == 1 and np.random.uniform(0,1) < 0.1: + self.state.percoxyg_state = 0 + self.state.vent_state = 0 + + def transition_vaso_on(self): + ''' + vasopressor state on + for non-diabetic: + sys bp: low -> normal, normal -> hi w.p. .7 + for diabetic: + raise blood pressure: normal -> hi w.p. .9, + lo -> normal w.p. .5, lo -> hi w.p. .4 + raise blood glucose by 1 w.p. .5 + ''' + self.state.vaso_state = 1 + if self.state.diabetic_idx == 0: + if np.random.uniform(0,1) < 0.7: + if self.state.sysbp_state == 0: + self.state.sysbp_state = 1 + elif self.state.sysbp_state == 1: + self.state.sysbp_state = 2 + else: + if self.state.sysbp_state == 1: + if np.random.uniform(0,1) < 0.9: + self.state.sysbp_state = 2 + elif self.state.sysbp_state == 0: + up_prob = np.random.uniform(0,1) + if up_prob < 0.5: + self.state.sysbp_state = 1 + elif up_prob < 0.9: + self.state.sysbp_state = 2 + if np.random.uniform(0,1) < 0.5: + self.state.glucose_state = min(4, self.state.glucose_state + 1) + + def transition_vaso_off(self): + ''' + vasopressor state off + if vasopressor was on: + for non-diabetics, sys bp: normal -> low, hi -> normal w.p. .1 + for diabetics, blood pressure falls by 1 w.p. 
.05 instead of .1 + ''' + if self.state.vaso_state == 1: + if self.state.diabetic_idx == 0: + if np.random.uniform(0,1) < 0.1: + self.state.sysbp_state = max(0, self.state.sysbp_state - 1) + else: + if np.random.uniform(0,1) < 0.05: + self.state.sysbp_state = max(0, self.state.sysbp_state - 1) + self.state.vaso_state = 0 + + def transition_fluctuate(self, hr_fluctuate, sysbp_fluctuate, percoxyg_fluctuate, \ + glucose_fluctuate): + ''' + all (non-treatment) states fluctuate +/- 1 w.p. .1 + exception: glucose flucuates +/- 1 w.p. .3 if diabetic + ''' + if hr_fluctuate: + hr_prob = np.random.uniform(0,1) + if hr_prob < 0.1: + self.state.hr_state = max(0, self.state.hr_state - 1) + elif hr_prob < 0.2: + self.state.hr_state = min(2, self.state.hr_state + 1) + if sysbp_fluctuate: + sysbp_prob = np.random.uniform(0,1) + if sysbp_prob < 0.1: + self.state.sysbp_state = max(0, self.state.sysbp_state - 1) + elif sysbp_prob < 0.2: + self.state.sysbp_state = min(2, self.state.sysbp_state + 1) + if percoxyg_fluctuate: + percoxyg_prob = np.random.uniform(0,1) + if percoxyg_prob < 0.1: + self.state.percoxyg_state = max(0, self.state.percoxyg_state - 1) + elif percoxyg_prob < 0.2: + self.state.percoxyg_state = min(1, self.state.percoxyg_state + 1) + if glucose_fluctuate: + glucose_prob = np.random.uniform(0,1) + if self.state.diabetic_idx == 0: + if glucose_prob < 0.1: + self.state.glucose_state = max(0, self.state.glucose_state - 1) + elif glucose_prob < 0.2: + self.state.glucose_state = min(4, self.state.glucose_state + 1) + else: + if glucose_prob < 0.3: + self.state.glucose_state = max(0, self.state.glucose_state - 1) + elif glucose_prob < 0.6: + self.state.glucose_state = min(4, self.state.glucose_state + 1) + + def calculateReward(self): + num_abnormal = self.state.get_num_abnormal() + if num_abnormal >= 3: + return -1 + elif num_abnormal == 0 and not self.state.on_treatment(): + return 1 + return 0 + + def transition(self, action): + self.state = self.state.copy_state() + + 
if action.antibiotic == 1: + self.transition_antibiotics_on() + hr_fluctuate = False + sysbp_fluctuate = False + elif self.state.antibiotic_state == 1: + self.transition_antibiotics_off() + hr_fluctuate = False + sysbp_fluctuate = False + else: + hr_fluctuate = True + sysbp_fluctuate = True + + if action.ventilation == 1: + self.transition_vent_on() + percoxyg_fluctuate = False + elif self.state.vent_state == 1: + self.transition_vent_off() + percoxyg_fluctuate = False + else: + percoxyg_fluctuate = True + + glucose_fluctuate = True + + if action.vasopressors == 1: + self.transition_vaso_on() + sysbp_fluctuate = False + glucose_fluctuate = False + elif self.state.vaso_state == 1: + self.transition_vaso_off() + sysbp_fluctuate = False + + self.transition_fluctuate(hr_fluctuate, sysbp_fluctuate, percoxyg_fluctuate, \ + glucose_fluctuate) + + return self.calculateReward() + + def select_actions(self): + assert self.policy_array is not None + probs = self.policy_array[ + self.state.get_state_idx(self.policy_idx_type) + ] + aev_idx = np.random.choice(np.arange(Action.NUM_ACTIONS_TOTAL), p=probs) + return Action(action_idx = aev_idx) diff --git a/sepsisSim/data-prep/sepsisSimDiabetes/State.py b/sepsisSim/data-prep/sepsisSimDiabetes/State.py new file mode 100644 index 0000000..cd5cec2 --- /dev/null +++ b/sepsisSim/data-prep/sepsisSimDiabetes/State.py @@ -0,0 +1,229 @@ +import numpy as np + +''' +Includes blood glucose level proxy for diabetes: 0-3 + (lo2 - counts as abnormal, lo1, normal, hi1, hi2 - counts as abnormal) +Initial distribution: + [.05, .15, .6, .15, .05] for non-diabetics and [.01, .05, .15, .6, .19] for diabetics +''' + +class State(object): + + NUM_OBS_STATES = 720 + NUM_HID_STATES = 2 # Binary value of diabetes + NUM_PROJ_OBS_STATES = int(720 / 5) # Marginalizing over glucose + NUM_FULL_STATES = int(NUM_OBS_STATES * NUM_HID_STATES) + + def __init__(self, + state_idx = None, idx_type = 'obs', + diabetic_idx = None, state_categs = None): + + assert 
state_idx is not None or state_categs is not None + assert ((diabetic_idx is not None and diabetic_idx in [0, 1]) or + (state_idx is not None and idx_type == 'full')) + + assert idx_type in ['obs', 'full', 'proj_obs'] + + if state_idx is not None: + self.set_state_by_idx( + state_idx, idx_type=idx_type, diabetic_idx=diabetic_idx) + elif state_categs is not None: + assert len(state_categs) == 7, "must specify 7 state variables" + self.hr_state = state_categs[0] + self.sysbp_state = state_categs[1] + self.percoxyg_state = state_categs[2] + self.glucose_state = state_categs[3] + self.antibiotic_state = state_categs[4] + self.vaso_state = state_categs[5] + self.vent_state = state_categs[6] + self.diabetic_idx = diabetic_idx + + def check_absorbing_state(self): + num_abnormal = self.get_num_abnormal() + if num_abnormal >= 3: + return True + elif num_abnormal == 0 and not self.on_treatment(): + return True + return False + + def set_state_by_idx(self, state_idx, idx_type, diabetic_idx=None): + """set_state_by_idx + + The state index is determined by using "bit" arithmetic, with the + complication that not every state is binary + + :param state_idx: Given index + :param idx_type: Index type, either observed (720), projected (144) or + full (1440) + :param diabetic_idx: If full state index not given, this is required + """ + if idx_type == 'obs': + term_base = State.NUM_OBS_STATES/3 # Starts with heart rate + elif idx_type == 'proj_obs': + term_base = State.NUM_PROJ_OBS_STATES/3 + elif idx_type == 'full': + term_base = State.NUM_FULL_STATES/2 # Starts with diab + + # Start with the given state index + mod_idx = state_idx + + if idx_type == 'full': + self.diabetic_idx = np.floor(mod_idx/term_base).astype(int) + mod_idx %= term_base + term_base /= 3 # This is for heart rate, the next item + else: + assert diabetic_idx is not None + self.diabetic_idx = diabetic_idx + + self.hr_state = np.floor(mod_idx/term_base).astype(int) + + mod_idx %= term_base + term_base /= 3 + 
self.sysbp_state = np.floor(mod_idx/term_base).astype(int) + + mod_idx %= term_base + term_base /= 2 + self.percoxyg_state = np.floor(mod_idx/term_base).astype(int) + + if idx_type == 'proj_obs': + self.glucose_state = 2 + else: + mod_idx %= term_base + term_base /= 5 + self.glucose_state = np.floor(mod_idx/term_base).astype(int) + + mod_idx %= term_base + term_base /= 2 + self.antibiotic_state = np.floor(mod_idx/term_base).astype(int) + + mod_idx %= term_base + term_base /= 2 + self.vaso_state = np.floor(mod_idx/term_base).astype(int) + + mod_idx %= term_base + term_base /= 2 + self.vent_state = np.floor(mod_idx/term_base).astype(int) + + def get_state_idx(self, idx_type='obs'): + ''' + returns integer index of state: significance order as in categorical array + ''' + if idx_type == 'obs': + categ_num = np.array([3,3,2,5,2,2,2]) + state_categs = [ + self.hr_state, + self.sysbp_state, + self.percoxyg_state, + self.glucose_state, + self.antibiotic_state, + self.vaso_state, + self.vent_state] + elif idx_type == 'proj_obs': + categ_num = np.array([3,3,2,2,2,2]) + state_categs = [ + self.hr_state, + self.sysbp_state, + self.percoxyg_state, + self.antibiotic_state, + self.vaso_state, + self.vent_state] + elif idx_type == 'full': + categ_num = np.array([2,3,3,2,5,2,2,2]) + state_categs = [ + self.diabetic_idx, + self.hr_state, + self.sysbp_state, + self.percoxyg_state, + self.glucose_state, + self.antibiotic_state, + self.vaso_state, + self.vent_state] + + sum_idx = 0 + prev_base = 1 + for i in range(len(state_categs)): + idx = len(state_categs) - 1 - i + sum_idx += prev_base*state_categs[idx] + prev_base *= categ_num[idx] + return sum_idx + + def __eq__(self, other): + ''' + override equals: two states equal if all internal states same + ''' + return isinstance(other, self.__class__) and \ + self.hr_state == other.hr_state and \ + self.sysbp_state == other.sysbp_state and \ + self.percoxyg_state == other.percoxyg_state and \ + self.glucose_state == other.glucose_state 
and \ + self.antibiotic_state == other.antibiotic_state and \ + self.vaso_state == other.vaso_state and \ + self.vent_state == other.vent_state + + def __ne__(self, other): + return not self.__eq__(other) + + def __hash__(self): + return self.get_state_idx() + + def get_num_abnormal(self): + ''' + returns number of abnormal conditions + ''' + num_abnormal = 0 + if self.hr_state != 1: + num_abnormal += 1 + if self.sysbp_state != 1: + num_abnormal += 1 + if self.percoxyg_state != 1: + num_abnormal += 1 + if self.glucose_state != 2: + num_abnormal += 1 + return num_abnormal + + def on_treatment(self): + ''' + returns True iff any of 3 treatments active + ''' + if self.antibiotic_state == 0 and \ + self.vaso_state == 0 and self.vent_state == 0: + return False + return True + + def on_antibiotics(self): + ''' + returns True iff antibiotics active + ''' + return self.antibiotic_state == 1 + + def on_vasopressors(self): + ''' + returns True iff vasopressors active + ''' + return self.vaso_state == 1 + + def on_ventilation(self): + ''' + returns True iff ventilation active + ''' + return self.vent_state == 1 + + def copy_state(self): + return State(state_categs = [ + self.hr_state, + self.sysbp_state, + self.percoxyg_state, + self.glucose_state, + self.antibiotic_state, + self.vaso_state, + self.vent_state], + diabetic_idx=self.diabetic_idx) + + def get_state_vector(self): + return np.array([self.hr_state, + self.sysbp_state, + self.percoxyg_state, + self.glucose_state, + self.antibiotic_state, + self.vaso_state, + self.vent_state]).astype(int) diff --git a/sepsisSim/data-prep/sepsisSimDiabetes/__init__.py b/sepsisSim/data-prep/sepsisSimDiabetes/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/sepsisSim/experiments/0-prepopulate_annotations.ipynb b/sepsisSim/experiments/0-prepopulate_annotations.ipynb new file mode 100644 index 0000000..c23a1a9 --- /dev/null +++ b/sepsisSim/experiments/0-prepopulate_annotations.ipynb @@ -0,0 +1,873 @@ +{ + "cells": [ + 
{ + "cell_type": "markdown", + "id": "6a6b03ce-53e5-4397-bdc8-0cd7dac74198", + "metadata": {}, + "source": [ + "Configurations\n", + "- dataset: vaso_eps_0_1\n", + "- annotation function: annotOpt" + ] + }, + { + "cell_type": "code", + "execution_count": 2, + "id": "01f0175c-3e58-4a93-807a-b7ffa5a00531", + "metadata": {}, + "outputs": [], + "source": [ + "import matplotlib\n", + "matplotlib.rcParams['font.sans-serif'] = ['FreeSans']\n", + "%config InlineBackend.figure_formats = ['svg']" + ] + }, + { + "cell_type": "code", + "execution_count": 3, + "id": "62d6125a-5935-4e5b-80ff-4891f32ec784", + "metadata": {}, + "outputs": [], + "source": [ + "import numpy as np\n", + "import pandas as pd\n", + "from tqdm import tqdm\n", + "from collections import defaultdict\n", + "import pickle\n", + "import itertools\n", + "import copy\n", + "import random\n", + "import matplotlib.pyplot as plt\n", + "import seaborn as sns\n", + "import scipy.stats\n", + "from sklearn import metrics\n", + "import itertools\n", + "\n", + "import joblib\n", + "from joblib import Parallel, delayed" + ] + }, + { + "cell_type": "code", + "execution_count": 4, + "id": "8c4244f8-4324-494a-ae91-284f6e674169", + "metadata": {}, + "outputs": [], + "source": [ + "from OPE_utils_new import (\n", + " format_data_tensor,\n", + " policy_eval_analytic_finite,\n", + " OPE_IS_h,\n", + " compute_behavior_policy_h,\n", + ")" + ] + }, + { + "cell_type": "code", + "execution_count": 5, + "id": "81549d03-e43f-42fa-882e-56a87265835c", + "metadata": {}, + "outputs": [], + "source": [ + "NSTEPS = H = 20 # max episode length in historical data # Horizon of the MDP\n", + "G_min = -1 # the minimum possible return\n", + "G_max = 1 # the maximum possible return\n", + "nS, nA = 1442, 8\n", + "\n", + "PROB_DIAB = 0.2" + ] + }, + { + "cell_type": "code", + "execution_count": 6, + "id": "964736ac-8bcc-45ee-820d-5964e9df807b", + "metadata": {}, + "outputs": [], + "source": [ + "# Ground truth MDP model\n", + "MDP_parameters = 
joblib.load('../data/MDP_parameters.joblib')\n", + "P = MDP_parameters['transition_matrix_absorbing'] # (A, S, S_next)\n", + "R = MDP_parameters['reward_matrix_absorbing_SA'] # (S, A)\n", + "nS, nA = R.shape\n", + "gamma = 0.99\n", + "\n", + "# unif rand isd, mixture of diabetic state\n", + "isd = joblib.load('../data/modified_prior_initial_state_absorbing.joblib')\n", + "isd = (isd > 0).astype(float)\n", + "isd[:720] = isd[:720] / isd[:720].sum() * (1-PROB_DIAB)\n", + "isd[720:] = isd[720:] / isd[720:].sum() * (PROB_DIAB)" + ] + }, + { + "cell_type": "code", + "execution_count": 7, + "id": "4287d548-043d-4f37-ae4c-ed92513f451f", + "metadata": {}, + "outputs": [], + "source": [ + "# Precomputed optimal policy\n", + "π_star = joblib.load('../data/π_star.joblib')" + ] + }, + { + "cell_type": "markdown", + "id": "2eb8e706-586b-4cb9-9df6-913e8e1763fc", + "metadata": {}, + "source": [ + "## Load data" + ] + }, + { + "cell_type": "code", + "execution_count": 8, + "id": "eaf5b9bc-86c1-49e2-a5a8-81f87458d437", + "metadata": {}, + "outputs": [], + "source": [ + "input_dir = '../datagen/vaso_eps_0_1-100k/'" + ] + }, + { + "cell_type": "code", + "execution_count": 9, + "id": "51a1bcf1-f9f1-40e4-9342-ea8b1ff587f5", + "metadata": {}, + "outputs": [], + "source": [ + "def load_data(fname):\n", + " print('Loading data', fname, '...', end='')\n", + " df_data = pd.read_csv('{}/{}'.format(input_dir, fname)).rename(columns={'State_idx': 'State'})#[['pt_id', 'Time', 'State', 'Action', 'Reward']]\n", + "\n", + " # Assign next state\n", + " df_data['NextState'] = [*df_data['State'].iloc[1:].values, -1]\n", + " df_data.loc[df_data.groupby('pt_id')['Time'].idxmax(), 'NextState'] = -1\n", + " df_data.loc[(df_data['Reward'] == -1), 'NextState'] = 1440\n", + " df_data.loc[(df_data['Reward'] == 1), 'NextState'] = 1441\n", + "\n", + " assert ((df_data['Reward'] != 0) == (df_data['Action'] == -1)).all()\n", + "\n", + " print('DONE')\n", + " return df_data" + ] + }, + { + "cell_type": "code", + 
"execution_count": 10, + "id": "5ee3d16f-24d5-4121-b5ec-2bbf06f74c73", + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Loading data 1-features.csv ...DONE\n", + "Loading data 2-features.csv ...DONE\n" + ] + } + ], + "source": [ + "df_seed1 = load_data('1-features.csv') # tr\n", + "df_seed2 = load_data('2-features.csv') # va" + ] + }, + { + "cell_type": "markdown", + "id": "c81658e9-5862-4670-a46a-b0978ce4bbcc", + "metadata": {}, + "source": [ + "## Policies" + ] + }, + { + "cell_type": "markdown", + "id": "a0ecf4c7-c8f3-4528-8c20-2db62f14126f", + "metadata": {}, + "source": [ + "### Behavior policy" + ] + }, + { + "cell_type": "code", + "execution_count": 11, + "id": "ac551e9e-f9ed-466e-8647-7ccf01cf07a7", + "metadata": {}, + "outputs": [], + "source": [ + "# vaso eps=0.5, mv abx optimal\n", + "π_beh = (np.tile(π_star.reshape((-1,2,2,2)).sum(axis=3, keepdims=True), (1,1,1,2)).reshape((-1, 8)) / 2)\n", + "π_beh[π_star == 1] = 0.9\n", + "π_beh[π_beh == 0.5] = 0.1" + ] + }, + { + "cell_type": "code", + "execution_count": 12, + "id": "dadef7ec-8ecc-4c0e-9ed8-ccaa36d52413", + "metadata": {}, + "outputs": [], + "source": [ + "V_H_beh = policy_eval_analytic_finite(P.transpose((1,0,2)), R, π_beh, gamma, H)\n", + "Q_H_beh = [(R + gamma * P.transpose((1,0,2)) @ V_H_beh[h]) for h in range(1,H)] + [R]" + ] + }, + { + "cell_type": "code", + "execution_count": 13, + "id": "e3cdacbd-0a4c-4244-8823-d3f6ff1938e2", + "metadata": {}, + "outputs": [ + { + "data": { + "text/plain": [ + "0.2503835479385116" + ] + }, + "execution_count": 13, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "J_beh = isd @ V_H_beh[0]\n", + "J_beh" + ] + }, + { + "cell_type": "code", + "execution_count": 14, + "id": "13d1316d-65c2-4e92-a5a0-44472aea3e10", + "metadata": {}, + "outputs": [], + "source": [ + "# Check recursive relationships\n", + "assert len(Q_H_beh) == H\n", + "assert len(V_H_beh) == H\n", + "assert 
np.all(Q_H_beh[-1] == R)\n", + "assert np.all(np.sum(π_beh * Q_H_beh[-1], axis=1) == V_H_beh[-1])\n", + "assert np.all(R + gamma * P.transpose((1,0,2)) @ V_H_beh[-1] == Q_H_beh[-2])" + ] + }, + { + "cell_type": "markdown", + "id": "f76f21ef-f8b2-429e-8459-398f07e1ccc8", + "metadata": {}, + "source": [ + "### Evaluation policy" + ] + }, + { + "cell_type": "code", + "execution_count": 15, + "id": "9bb8d86e-964d-4c07-9305-88a2f83018c4", + "metadata": {}, + "outputs": [], + "source": [ + "π_eval = π_star" + ] + }, + { + "cell_type": "code", + "execution_count": 16, + "id": "08b7f857-e65e-44ab-8bfc-d4cd84230cf2", + "metadata": {}, + "outputs": [], + "source": [ + "V_H_eval = policy_eval_analytic_finite(P.transpose((1,0,2)), R, π_eval, gamma, H)\n", + "Q_H_eval = [(R + gamma * P.transpose((1,0,2)) @ V_H_eval[h]) for h in range(1,H)] + [R]" + ] + }, + { + "cell_type": "code", + "execution_count": 17, + "id": "fdc062d9-d2e6-449d-83b4-2d146bd2a381", + "metadata": {}, + "outputs": [ + { + "data": { + "text/plain": [ + "0.40877179296760235" + ] + }, + "execution_count": 17, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "J_eval = isd @ V_H_eval[0]\n", + "J_eval" + ] + }, + { + "cell_type": "code", + "execution_count": 18, + "id": "72a16631-b513-4bb6-8e57-f91d47752b04", + "metadata": {}, + "outputs": [], + "source": [ + "# Check recursive relationships\n", + "assert len(Q_H_eval) == H\n", + "assert len(V_H_eval) == H\n", + "assert np.all(Q_H_eval[-1] == R)\n", + "assert np.all(np.sum(π_eval * Q_H_eval[-1], axis=1) == V_H_eval[-1])\n", + "assert np.all(R + gamma * P.transpose((1,0,2)) @ V_H_eval[-1] == Q_H_eval[-2])" + ] + }, + { + "cell_type": "markdown", + "id": "c5648b28-b673-430e-87b0-363a0df84eb3", + "metadata": { + "tags": [] + }, + "source": [ + "## Pre-populate counterfactual annotations for offline dataset" + ] + }, + { + "cell_type": "markdown", + "id": "de81fdc5-bb50-4636-b56b-f89f3a2f6c14", + "metadata": {}, + "source": [ + "### version 
1: annotate counterfactuals only for initial states" + ] + }, + { + "cell_type": "code", + "execution_count": 19, + "id": "4fe8811a-3107-44ae-bb77-bee3c414cc61", + "metadata": {}, + "outputs": [], + "source": [ + "df_va = df_seed2[['pt_id', 'Time', 'State', 'Action', 'Reward', 'NextState']].copy()\n", + "\n", + "# assign alternative action for vaso\n", + "df_va['Action:Vaso'] = df_va['Action'] % 2\n", + "df_va.loc[df_va['Action'] == -1, 'Action:Vaso'] = -1\n", + "df_va['Action_Alt'] = df_va['Action'] + 1 - 2*df_va['Action:Vaso']\n", + "df_va.loc[df_va['Action'] == -1, 'Action_Alt'] = -1" + ] + }, + { + "cell_type": "code", + "execution_count": 20, + "id": "55665ad7-e617-4b30-bfc7-8f7cee20048f", + "metadata": {}, + "outputs": [], + "source": [ + "# for each original traj, create a length 1 pseudo-traj by flipping the action from starting state\n", + "def _func_v1(df_i):\n", + " df_i_new = []\n", + " for t in range(len(df_i) - 1):\n", + " if t > 0: break\n", + " step, S, A_alt = df_i.iloc[t]['Time'], df_i.iloc[t]['State'], df_i.iloc[t]['Action_Alt']\n", + " df_i_t = df_i.iloc[:t].loc[:, ['Time', 'State', 'Action', 'Reward', 'NextState']].append(\n", + " pd.Series({\n", + " 'Time': step,\n", + " 'State': S,\n", + " 'Action': A_alt,\n", + " 'Reward': Q_H_eval[t][S, A_alt],\n", + " 'NextState': 1442, # truncation indicator\n", + " }), ignore_index=True,\n", + " )\n", + " df_i_t['pt_id'] = df_i['pt_id'].iloc[0] + (t+1)*0.01\n", + " df_i_t = df_i_t[['pt_id', 'Time', 'State', 'Action', 'Reward', 'NextState']]\n", + " df_i_new.append(df_i_t)\n", + " return df_i_new" + ] + }, + { + "cell_type": "code", + "execution_count": 21, + "id": "f1f945bb-5dc3-44b6-af82-26b73a0189dc", + "metadata": {}, + "outputs": [ + { + "name": "stderr", + "output_type": "stream", + "text": [ + "100%|██████████| 100000/100000 [13:01<00:00, 128.00it/s]\n" + ] + } + ], + "source": [ + "df_va_new1 = Parallel(n_jobs=50)(delayed(_func_v1)(df_i) for i, df_i in tqdm(df_va.groupby('pt_id')))\n", + 
"df_va_new1 = pd.concat(itertools.chain.from_iterable(df_va_new1)).reset_index(drop=True)\n", + "df_va_new1[['Time', 'State', 'Action', 'NextState']] = df_va_new1[['Time', 'State', 'Action', 'NextState']].astype(int)" + ] + }, + { + "cell_type": "code", + "execution_count": 22, + "id": "50c972fa-e297-4f3d-9d13-438f06153b25", + "metadata": {}, + "outputs": [], + "source": [ + "df_va_all1 = pd.concat([\n", + " df_va[['pt_id', 'Time', 'State', 'Action', 'Reward', 'NextState']], \n", + " df_va_new1]).sort_values(by=['pt_id', 'Time']).reset_index(drop=True)" + ] + }, + { + "cell_type": "code", + "execution_count": 23, + "id": "0485ac08-cc67-43cd-afba-636bdb830588", + "metadata": {}, + "outputs": [ + { + "data": { + "text/html": [ + "
\n", + "\n", + "\n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + "
pt_idTimeStateActionRewardNextState
0200000.00033970.000000463
1200000.00146360.000000381
2200000.00238100.000000376
3200000.003376-11.0000001441
4200000.01033960.6785551442
.....................
1475561299999.001636560.000000365
1475562299999.001736560.000000365
1475563299999.001836560.000000365
1475564299999.001936560.000000-1
1475565299999.01028660.2412581442
\n", + "

1475566 rows × 6 columns

\n", + "
" + ], + "text/plain": [ + " pt_id Time State Action Reward NextState\n", + "0 200000.00 0 339 7 0.000000 463\n", + "1 200000.00 1 463 6 0.000000 381\n", + "2 200000.00 2 381 0 0.000000 376\n", + "3 200000.00 3 376 -1 1.000000 1441\n", + "4 200000.01 0 339 6 0.678555 1442\n", + "... ... ... ... ... ... ...\n", + "1475561 299999.00 16 365 6 0.000000 365\n", + "1475562 299999.00 17 365 6 0.000000 365\n", + "1475563 299999.00 18 365 6 0.000000 365\n", + "1475564 299999.00 19 365 6 0.000000 -1\n", + "1475565 299999.01 0 286 6 0.241258 1442\n", + "\n", + "[1475566 rows x 6 columns]" + ] + }, + "execution_count": 23, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "df_va_all1" + ] + }, + { + "cell_type": "code", + "execution_count": 24, + "id": "a4bcde0b-d3e2-43da-89eb-36ccd4b3c3de", + "metadata": {}, + "outputs": [], + "source": [ + "df_va_all1.to_pickle('results/vaso_eps_0_1-evalOpt_df_seed2_aug_init.pkl')" + ] + }, + { + "cell_type": "markdown", + "id": "a001b005-0f47-4675-8723-af2c3ea9fa27", + "metadata": {}, + "source": [ + "### version 2: annotate counterfactuals for all time steps" + ] + }, + { + "cell_type": "code", + "execution_count": 25, + "id": "41600dc9-feb6-4e92-9ff8-31c0744056ab", + "metadata": {}, + "outputs": [], + "source": [ + "df_va = df_seed2[['pt_id', 'Time', 'State', 'Action', 'Reward', 'NextState']].copy()\n", + "\n", + "# assign alternative action for vaso\n", + "df_va['Action:Vaso'] = df_va['Action'] % 2\n", + "df_va.loc[df_va['Action'] == -1, 'Action:Vaso'] = -1\n", + "df_va['Action_Alt'] = df_va['Action'] + 1 - 2*df_va['Action:Vaso']\n", + "df_va.loc[df_va['Action'] == -1, 'Action_Alt'] = -1" + ] + }, + { + "cell_type": "code", + "execution_count": 26, + "id": "19bf115c-e92e-46e8-9c12-d72d5bb53d8f", + "metadata": {}, + "outputs": [], + "source": [ + "def _func_v2(df_i):\n", + " df_i_new = []\n", + " for t in range(len(df_i) - 1):\n", + " step, S, A_alt = df_i.iloc[t]['Time'], df_i.iloc[t]['State'], 
df_i.iloc[t]['Action_Alt']\n", + " df_i_t = df_i.iloc[:t].loc[:, ['Time', 'State', 'Action', 'Reward', 'NextState']].append(\n", + " pd.Series({\n", + " 'Time': step,\n", + " 'State': S,\n", + " 'Action': A_alt,\n", + " 'Reward': Q_H_eval[t][S, A_alt],\n", + " 'NextState': 1442, # truncation indicator\n", + " }), ignore_index=True,\n", + " )\n", + " df_i_t['pt_id'] = df_i['pt_id'].iloc[0] + (t+1)*0.01\n", + " df_i_t = df_i_t[['pt_id', 'Time', 'State', 'Action', 'Reward', 'NextState']]\n", + " df_i_new.append(df_i_t)\n", + " return df_i_new" + ] + }, + { + "cell_type": "code", + "execution_count": 27, + "id": "e0d3abfe-d236-446c-b43b-f8a71d7f4e77", + "metadata": {}, + "outputs": [ + { + "name": "stderr", + "output_type": "stream", + "text": [ + "100%|██████████| 100000/100000 [29:24<00:00, 56.69it/s]\n" + ] + } + ], + "source": [ + "df_va_new2 = Parallel(n_jobs=100)(delayed(_func_v2)(df_i) for i, df_i in tqdm(df_va.groupby('pt_id')))\n", + "df_va_new2 = pd.concat(itertools.chain.from_iterable(df_va_new2)).reset_index(drop=True)\n", + "df_va_new2[['Time', 'State', 'Action', 'NextState']] = df_va_new2[['Time', 'State', 'Action', 'NextState']].astype(int)" + ] + }, + { + "cell_type": "code", + "execution_count": 28, + "id": "2ec67306-3994-40aa-bba5-296b4ed4b410", + "metadata": {}, + "outputs": [], + "source": [ + "df_va_all2 = pd.concat([\n", + " df_va[['pt_id', 'Time', 'State', 'Action', 'Reward', 'NextState']], \n", + " df_va_new2]).sort_values(by=['pt_id', 'Time']).reset_index(drop=True)" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "93f14a15", + "metadata": {}, + "outputs": [ + { + "data": { + "text/html": [ + "
\n", + "\n", + "\n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + "
pt_idTimeStateActionRewardNextState
0200000.00033970.000000463
1200000.00146360.000000381
2200000.00238100.000000376
3200000.003376-11.0000001441
4200000.01033960.6785551442
.....................
12463512299999.191436560.000000365
12463513299999.191536560.000000365
12463514299999.191636560.000000365
12463515299999.191736560.000000365
12463516299999.191836570.0000001442
\n", + "

12463517 rows × 6 columns

\n", + "
" + ], + "text/plain": [ + " pt_id Time State Action Reward NextState\n", + "0 200000.00 0 339 7 0.000000 463\n", + "1 200000.00 1 463 6 0.000000 381\n", + "2 200000.00 2 381 0 0.000000 376\n", + "3 200000.00 3 376 -1 1.000000 1441\n", + "4 200000.01 0 339 6 0.678555 1442\n", + "... ... ... ... ... ... ...\n", + "12463512 299999.19 14 365 6 0.000000 365\n", + "12463513 299999.19 15 365 6 0.000000 365\n", + "12463514 299999.19 16 365 6 0.000000 365\n", + "12463515 299999.19 17 365 6 0.000000 365\n", + "12463516 299999.19 18 365 7 0.000000 1442\n", + "\n", + "[12463517 rows x 6 columns]" + ] + }, + "metadata": {}, + "output_type": "display_data" + } + ], + "source": [ + "df_va_all2" + ] + }, + { + "cell_type": "code", + "execution_count": 29, + "id": "f34774c1-f442-4a64-8c2e-6ec0ee511a74", + "metadata": {}, + "outputs": [], + "source": [ + "df_va_all2.to_pickle('results/vaso_eps_0_1-evalOpt_df_seed2_aug_step.pkl')" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "9b9d4ea0-3a83-4aa4-b6a5-f0d3c3522784", + "metadata": {}, + "outputs": [], + "source": [] + } + ], + "metadata": { + "kernelspec": { + "display_name": "RL_venv", + "language": "python", + "name": "rl_venv" + }, + "language_info": { + "codemirror_mode": { + "name": "ipython", + "version": 3 + }, + "file_extension": ".py", + "mimetype": "text/x-python", + "name": "python", + "nbconvert_exporter": "python", + "pygments_lexer": "ipython3", + "version": "3.8.8" + } + }, + "nbformat": 4, + "nbformat_minor": 5 +} diff --git a/sepsisSim/experiments/OPE_utils_new.py b/sepsisSim/experiments/OPE_utils_new.py new file mode 100644 index 0000000..d72afcb --- /dev/null +++ b/sepsisSim/experiments/OPE_utils_new.py @@ -0,0 +1,177 @@ +import numpy as np +import pandas as pd +import numpy_indexed as npi +import joblib +from tqdm import tqdm +import itertools +import copy + +NSTEPS = H = 20 # max episode length in historical data +G_min = -1 # the minimum possible return +G_max = 1 # the maximum possible 
nS, nA = 1442, 8  # discrete state / action space sizes used throughout this module


def compute_behavior_policy(df_data):
    """Maximum-likelihood estimate of the behavior policy π_b.

    Counts (State, Action) visits in `df_data`. A row whose Action is the
    terminal marker -1 spreads its count across every action of that state;
    states never visited fall back to a uniform distribution.
    Returns an (nS, nA) array whose rows sum to 1.
    """
    visit_counts = np.zeros((nS, nA))

    # Weighted nothing here — plain visit counts per (State, Action) pair
    grouped = (df_data.groupby(['State', 'Action'])['Reward']
               .count().reset_index(name='count'))

    for _, rec in grouped.iterrows():
        s, a, c = rec['State'], rec['Action'], rec['count']
        if a == -1:
            visit_counts[s, :] = c  # terminal marker: fill the whole row
        else:
            visit_counts[s, a] = c

    # Unvisited states: assume a uniform policy
    visit_counts[visit_counts.sum(axis=-1) == 0, :] = 1

    # Row-normalize counts into probabilities
    return visit_counts / visit_counts.sum(axis=-1, keepdims=True)
def policy_eval_analytic(P, R, π, γ):
    """
    Given the MDP model transition probability P (S,A,S) and reward function R (S,A),
    compute the state-value function of a stochastic policy π (S,A) analytically.

    Solves the Bellman linear system (I - γ P_π) V_π = R_π with
    `np.linalg.solve` rather than forming the explicit inverse — same
    result, but faster and more numerically stable.
    """
    nS, nA = R.shape
    R_π = np.sum(R * π, axis=1)                     # policy-averaged reward, shape (S,)
    P_π = np.sum(P * np.expand_dims(π, 2), axis=1)  # policy-averaged transitions, shape (S,S)
    V_π = np.linalg.solve(np.eye(nS) - γ * P_π, R_π)
    return V_π
+ γ^{h-1} P_π^{h-1} R_π + """ + nS, nA = R.shape + R_π = np.sum(R * π, axis=1) + P_π = np.sum(P * np.expand_dims(π, 2), axis=1) + V_π = [R_π] + for h in range(1,H): + V_π.append(R_π + γ * P_π @ V_π[-1]) + return list(reversed(V_π)) + +def OPE_IS_h(data, π_b, π_e, γ, epsilon=0.01): + """ + - π_b, π_e: behavior/evaluation policy, shape (S,A) + """ + # Get a soft version of the evaluation policy for WIS + π_e_soft = np.copy(π_e).astype(float) + π_e_soft[π_e_soft == 1] = (1 - epsilon) + π_e_soft[π_e_soft == 0] = epsilon / (nA - 1) + + # Apply WIS + return _is_h(data, π_b, π_e_soft, γ) + +def _is_h(data, π_b, π_e, γ): + """ + Weighted Importance Sampling for Off-Policy Evaluation + - data: tensor of shape (N, T, 5) with the last last dimension being [t, s, a, r, s'] + - π_b: behavior policy + - π_e: evaluation policy (aka target policy) + - γ: discount factor + """ + t_list = data[..., 0].astype(int) + s_list = data[..., 1].astype(int) + a_list = data[..., 2].astype(int) + r_list = data[..., 3].astype(float) + + # Per-trajectory returns (discounted cumulative rewards) + G = (r_list * np.power(γ, t_list)).sum(axis=-1) + + # Per-transition importance ratios + p_b = π_b[t_list, s_list, a_list] + p_e = π_e[s_list, a_list] + + # Deal with variable length sequences by setting ratio to 1 + terminated_idx = (a_list == -1) + p_b[terminated_idx] = 1 + p_e[terminated_idx] = 1 + + if not np.all(p_b > 0): + import pdb + pdb.set_trace() + assert np.all(p_b > 0), "Some actions had zero prob under p_b, WIS fails" + + # Per-trajectory cumulative importance ratios, take the product + rho = (p_e / p_b).prod(axis=1) + rho_norm = rho / rho.sum() + + # directly calculate weighted average over trajectories + is_value = np.average(G*rho) # (G @ rho) / len(G) + wis_value = np.average(G, weights=rho) # (G @ rho_norm) + ess1 = 1 / (rho_norm ** 2).sum() + ess1_ = (rho.sum()) ** 2 / ((rho ** 2).sum()) + assert np.isclose(ess1, ess1_) + ess2 = 1. 
/ rho_norm.max() + return is_value, wis_value, { + 'ESS1': ess1, 'ESS2': ess2, 'G': G, + 'rho': rho, 'rho_norm': rho_norm + } diff --git a/sepsisSim/experiments/commands.sh b/sepsisSim/experiments/commands.sh new file mode 100644 index 0000000..b50a32f --- /dev/null +++ b/sepsisSim/experiments/commands.sh @@ -0,0 +1,113 @@ +mkdir -p results/exp-FINAL-1/ +mkdir -p results/exp-FINAL-2/ +mkdir -p results/exp-FINAL-3/ +mkdir -p results/exp-FINAL-4/ +mkdir -p results/exp-FINAL-5/ + +python exp-1-observed.py & +python exp-1-onpolicy-baseline.py & + +## Baseline +python exp-1-baseline.py --flip_num=0 --flip_seed=0 & +for FLIP_NUM in 25 50 100 200 300 400; do for SEED in 0 42 123 424242 10000; do +python exp-1-baseline.py --flip_num=$FLIP_NUM --flip_seed=$SEED &> /dev/null & +done; done + + +## Naive implementation + +# Naive weighted, Annot π_e +python exp-1-Naive.py --flip_num=0 --flip_seed=0 & +for FLIP_NUM in 25 50 100 200 300 400; do for SEED in 0 42 123 424242 10000; do +python exp-1-Naive.py --flip_num=$FLIP_NUM --flip_seed=$SEED &> /dev/null & +done; done + +# Naive UnWeighted, Annot π_e +python exp-1-NaiveUW.py --flip_num=0 --flip_seed=0 & +for FLIP_NUM in 25 50 100 200 300 400; do for SEED in 0 42 123 424242 10000; do +python exp-1-NaiveUW.py --flip_num=$FLIP_NUM --flip_seed=$SEED &> /dev/null & +done; done + + +## Poposed approach + +# Annot π_e +python exp-2-annotEval.py --flip_num=0 --flip_seed=0 & +for FLIP_NUM in 25 50 100 200 300 400; do for SEED in 0 42 123 424242 10000; do +python exp-2-annotEval.py --flip_num=$FLIP_NUM --flip_seed=$SEED &> /dev/null & +done; done + +# Annot π_b +python exp-2-annotBeh.py --flip_num=0 --flip_seed=0 & +for FLIP_NUM in 25 50 100 200 300 400; do for SEED in 0 42 123 424242 10000; do +python exp-2-annotBeh.py --flip_num=$FLIP_NUM --flip_seed=$SEED &> /dev/null & +done; done + +# Annot zero +python exp-2-annotZero.py --flip_num=0 --flip_seed=0 & +for FLIP_NUM in 25 50 100 200 300 400; do for SEED in 0 42 123 424242 10000; do 
+python exp-2-annotZero.py --flip_num=$FLIP_NUM --flip_seed=$SEED &> /dev/null & +done; done + + +########## + +## Noisy annotations +for NOISE in '0.0' '0.1' '0.2' '0.3' '0.4' '0.5' '0.6' '0.7' '0.8' '0.9' '1.0'; do +python runs-3-annotEvalNoise.py --flip_num=0 --flip_seed=0 --annot_noise=$NOISE &> /dev/null & +done + +for FLIP_NUM in 25 50 100 200 300 400; do for SEED in 0 42 123 424242 10000; do +for NOISE in '0.0' '0.1' '0.2' '0.3' '0.4' '0.5' '0.6' '0.7' '0.8' '0.9' '1.0'; do +python runs-3-annotEvalNoise.py --flip_num=$FLIP_NUM --flip_seed=$SEED --annot_noise=$NOISE &> /dev/null; +done & +done; done + + + +########## + +## Missing annotations without and with imputation +for RATIO in '0.0' '0.1' '0.2' '0.3' '0.4' '0.5' '0.6' '0.7' '0.8' '0.9' '1.0'; do +python runs-4-annotEvalMissing.py --flip_num=0 --flip_seed=0 --annot_ratio=$RATIO &> /dev/null; +python runs-4-annotEvalMissingImpute.py --flip_num=0 --flip_seed=0 --annot_ratio=$RATIO &> /dev/null; +done & + +for FLIP_NUM in 25 50 100 200 300 400; do for SEED in 0 42 123 424242 10000; do +for RATIO in '0.0' '0.1' '0.2' '0.3' '0.4' '0.5' '0.6' '0.7' '0.8' '0.9' '1.0'; do +python runs-4-annotEvalMissing.py --flip_num=$FLIP_NUM --flip_seed=$SEED --annot_ratio=$RATIO &> /dev/null; +python runs-4-annotEvalMissingImpute.py --flip_num=$FLIP_NUM --flip_seed=$SEED --annot_ratio=$RATIO &> /dev/null; +done & +done; done + + + +########## + +## Annot π_b converted approx MDP +mkdir -p './results/runs-2_v4e/' +python runs-5-annotBehConvertedAM.py --flip_num=0 --flip_seed=0 &> /dev/null & +for FLIP_NUM in 25 50 100 200 300 400; do +for SEED in 0 42 123 424242 10000; do +python runs-5-annotBehConvertedAM.py --flip_num=$FLIP_NUM --flip_seed=$SEED &> /dev/null +done & +done + + + + + +########## + +## Noisy annotations, only 10% annotated without and with imputation +for NOISE in '0.0' '0.1' '0.3' '0.4' '0.5' '0.6' '0.7' '0.8' '0.9' '1.0'; do +python runs-4-annotEvalMissing.py --flip_num=0 --flip_seed=0 --annot_ratio='0.1' 
--annot_noise=$NOISE &> /dev/null; +python runs-4-annotEvalMissingImpute.py --flip_num=0 --flip_seed=0 --annot_ratio='0.1' --annot_noise=$NOISE &> /dev/null; +done & + +for FLIP_NUM in 25 50 100 200 300 400; do for SEED in 0 42 123 424242 10000; do +for NOISE in '0.0' '0.1' '0.3' '0.4' '0.5' '0.6' '0.7' '0.8' '0.9' '1.0'; do +python runs-4-annotEvalMissing.py --flip_num=$FLIP_NUM --flip_seed=$SEED --annot_ratio='0.1' --annot_noise=$NOISE &> /dev/null; +python runs-4-annotEvalMissingImpute.py --flip_num=$FLIP_NUM --flip_seed=$SEED --annot_ratio='0.1' --annot_noise=$NOISE &> /dev/null; +done & +done; done diff --git a/sepsisSim/experiments/exp-1-Naive.py b/sepsisSim/experiments/exp-1-Naive.py new file mode 100644 index 0000000..e99fd26 --- /dev/null +++ b/sepsisSim/experiments/exp-1-Naive.py @@ -0,0 +1,294 @@ +# ## Simulation parameters +exp_name = 'exp-FINAL-1' +eps = 0.10 +eps_str = '0_1' + +run_idx_length = 1_000 +N_val = 1_000 +runs = 50 + +# Number of action-flipped states +import argparse +parser = argparse.ArgumentParser() +parser.add_argument('--flip_num', type=int) +parser.add_argument('--flip_seed', type=int) +args = parser.parse_args() +pol_flip_num = args.flip_num +pol_flip_seed = args.flip_seed + +pol_name = f'flip{pol_flip_num}_seed{pol_flip_seed}' +out_fname = f'./results/{exp_name}/vaso_eps_{eps_str}-{pol_name}-aug_step-Naive.csv' + +import numpy as np +import pandas as pd + +df_tmp = None +try: + df_tmp = pd.read_csv(out_fname) +except: + pass + +if df_tmp is not None: + print('File exists') + quit() + +from tqdm import tqdm +from collections import defaultdict +import pickle +import itertools +import copy +import random +import itertools +import joblib +from joblib import Parallel, delayed + +from OPE_utils_new import ( + format_data_tensor, + policy_eval_analytic_finite, + OPE_IS_h, + compute_behavior_policy_h, +) + + +def policy_eval_helper(π): + V_H = policy_eval_analytic_finite(P.transpose((1,0,2)), R, π, gamma, H) + Q_H = [(R + gamma * 
def load_data(fname):
    """Load one generated-episode CSV from `input_dir` and derive the NextState column.

    Returns a DataFrame with columns including [pt_id, Time, State, Action,
    Reward, NextState]. NextState encoding: -1 on the last step of each
    trajectory, 1440 when Reward == -1 and 1441 when Reward == +1
    (absorbing outcome states — presumably death/survival; TODO confirm
    against the simulator's state indexing).
    """
    print('Loading data', fname, '...', end='')
    df_data = pd.read_csv('{}/{}'.format(input_dir, fname)).rename(columns={'State_idx': 'State'})#[['pt_id', 'Time', 'State', 'Action', 'Reward']]

    # Assign next state: shift State up by one row. The shift bleeds across
    # trajectory boundaries; the idxmax line below overwrites those rows.
    df_data['NextState'] = [*df_data['State'].iloc[1:].values, -1]
    # Last time step of each trajectory has no successor
    df_data.loc[df_data.groupby('pt_id')['Time'].idxmax(), 'NextState'] = -1
    # Nonzero rewards route to the absorbing outcome states
    df_data.loc[(df_data['Reward'] == -1), 'NextState'] = 1440
    df_data.loc[(df_data['Reward'] == 1), 'NextState'] = 1441

    # Sanity check: a nonzero reward occurs exactly on terminal (-1) actions
    assert ((df_data['Reward'] != 0) == (df_data['Action'] == -1)).all()

    print('DONE')
    return df_data
""" + Calculate probabilities of the behavior policy π_b + using Maximum Likelihood Estimation (MLE) + """ + # Compute empirical behavior policy from data + π_b = np.zeros((nS, nA)) + df_dataW = df_data.set_index('pt_id').join( + pd.DataFrame(trajW.sum(axis=1), + index=sorted(df_data['pt_id'].unique()), columns=['Weight']) + ).reset_index() + + sa_counts = df_dataW.groupby(['State', 'Action'])[['Weight']].sum().rename(columns={'Weight': 'count'}).reset_index() + # df_data.groupby(['State', 'Action']).count()[['Reward']].rename(columns={'Reward': 'count'}).reset_index() + + try: + for i, row in sa_counts.iterrows(): + s, a = int(row['State']), int(row['Action']) + count = row['count'] + if row['Action'] == -1: + π_b[s, :] = count + else: + π_b[s, a] = count + except: + print(s,a) + raise + # import pdb + # pdb.set_trace() + + # assume uniform action probabilities in unobserved states + unobserved_states = (π_b.sum(axis=-1) == 0) + π_b[unobserved_states, :] = 1 + + # normalize action probabilities + π_b = π_b / π_b.sum(axis=-1, keepdims=True) + + return π_b + + +def OPE_IS_trajW(data, π_b, π_e, γ, epsilon=0.01, trajW=None): + """ + - π_b, π_e: behavior/evaluation policy, shape (S,A) + """ + # Get a soft version of the evaluation policy for WIS + π_e_soft = np.copy(π_e).astype(float) + π_e_soft[π_e_soft == 1] = (1 - epsilon) + π_e_soft[π_e_soft == 0] = epsilon / (nA - 1) + + # Apply WIS + return _is_trajW(data, π_b, π_e_soft, γ, np.eye(len(data)) if trajW is None else trajW) + +def _is_trajW(data, π_b, π_e, γ, trajW): + """ + Weighted Importance Sampling for Off-Policy Evaluation + - data: tensor of shape (N, T, 5) with the last last dimension being [t, s, a, r, s'] + - π_b: behavior policy + - π_e: evaluation policy (aka target policy) + - γ: discount factor + """ + t_list = data[..., 0].astype(int) + s_list = data[..., 1].astype(int) + a_list = data[..., 2].astype(int) + r_list = data[..., 3].astype(float) + + # Per-trajectory returns (discounted cumulative rewards) 
+ G = (r_list * np.power(γ, t_list)).sum(axis=-1) + + # Per-transition importance ratios + p_b = π_b[s_list, a_list] + p_e = π_e[s_list, a_list] + + # Deal with variable length sequences by setting ratio to 1 + terminated_idx = (a_list == -1) + p_b[terminated_idx] = 1 + p_e[terminated_idx] = 1 + + # if not np.all(p_b > 0): + # import pdb + # pdb.set_trace() + # assert np.all(p_b > 0), "Some actions had zero prob under p_b, WIS fails" + + # Per-trajectory cumulative importance ratios, take the product + rho = (p_e / p_b).prod(axis=1) * trajW.sum(axis=1) + + # directly calculate weighted average over trajectories + is_value = np.nansum(G * rho) / trajW.shape[1] + wis_value = np.nansum(G * rho) / np.nansum(rho) + rho_norm = rho / np.nansum(rho) + rho_nna = rho[~np.isnan(rho)] + rho_norm_nna = rho_norm[~np.isnan(rho_norm)] + ess1 = 1 / np.nansum(rho_norm_nna ** 2) + ess1_ = (np.nansum(rho_nna)) ** 2 / (np.nansum(rho_nna ** 2)) + # assert np.isclose(ess1, ess1_) + ess2 = 1. / np.nanmax(rho_norm) + return is_value, wis_value, { + 'ESS1': ess1, 'ESS2': ess2, 'G': G, + 'rho': rho, 'rho_norm': rho_norm + } + + +# ## Policies + +# vaso unif, mv abx optimal +π_unif = (np.tile(π_star.reshape((-1,2,2,2)).sum(axis=3, keepdims=True), (1,1,1,2)).reshape((-1, 8)) / 2) + + +# ### Behavior policy + +# vaso eps=0.5, mv abx optimal +π_beh = (np.tile(π_star.reshape((-1,2,2,2)).sum(axis=3, keepdims=True), (1,1,1,2)).reshape((-1, 8)) / 2) +π_beh[π_star == 1] = 1-eps +π_beh[π_beh == 0.5] = eps + +V_H_beh, Q_H_beh, J_beh = policy_eval_helper(π_beh) +J_beh + + +# ### Optimal policy +V_H_star, Q_H_star, J_star = policy_eval_helper(π_star) +J_star + + +# ### flip action for x% states + +rng_flip = np.random.default_rng(pol_flip_seed) +flip_states = rng_flip.choice(range(1440), pol_flip_num, replace=False) + +π_tmp = (np.tile(π_star.reshape((-1,2,2,2)).sum(axis=3, keepdims=True), (1,1,1,2)).reshape((-1, 8)) / 2) +π_flip = π_tmp.copy() +π_flip[π_tmp == 0.5] = 0 +π_flip[π_star == 1] = 1 +for s in 
flip_states: + π_flip[s, π_tmp[s] == 0.5] = 1 + π_flip[s, π_star[s] == 1] = 0 +assert π_flip.sum(axis=1).mean() == 1 + +# np.savetxt(f'./results/{exp_name}/policy_{pol_name}.txt', π_flip) + + +# ## Compare OPE + +π_eval = π_flip + + +# ### Proposed: replace future with the value function for the evaluation policy + +df_va_all2 = pd.read_pickle(f'results/vaso_eps_{eps_str}-annotOpt_df_seed2_aug_step.pkl') +V_H_eval, Q_H_eval, J_eval = policy_eval_helper(π_eval) + + +df_results_v2 = [] +for run in range(runs): + df_va = df_va_all2.set_index('pt_id').loc[200000+run*run_idx_length:200000+run*run_idx_length + N_val - 1+0.999].reset_index() + + # m is num of trajectories + # (m_all, m_orig) table with binary indicator of the source (original traj) of each traj + df_idmap = df_va[['pt_id']].copy() + df_idmap['map_pt_id'] = df_va['pt_id'].apply(np.floor).astype(int) + df_idmap['mask'] = 1 + df_mapp = df_idmap.drop_duplicates().set_index(['pt_id', 'map_pt_id']).unstack().fillna(0).astype(pd.SparseDtype('int', 0)) + + # (m_all, m_orig) traj-wise weight matrix + # each column sums to 1 + traj_weight_matrix = df_mapp.values + traj_weight_matrix = (traj_weight_matrix / traj_weight_matrix.sum(axis=0)) + assert np.isclose(traj_weight_matrix.sum(axis=0), 1).all() + + # OPE - WIS/WDR prep + v2_data_va = format_data_tensor(df_va) + v2_pi_b_val = compute_behavior_policy_weighted(df_va, traj_weight_matrix) + + # OPE - WIS + v2_IS_value, v2_WIS_value, v2_ESS_info = OPE_IS_trajW(v2_data_va, v2_pi_b_val, π_eval, gamma, trajW=traj_weight_matrix) + + df_results_v2.append([v2_IS_value, v2_WIS_value, v2_ESS_info['ESS1'], v2_ESS_info['ESS2']]) + pd.DataFrame(df_results_v2, columns=['IS_value', 'WIS_value', 'ESS1', 'ESS2']).to_csv(out_fname) + +df_results_v2 = pd.DataFrame(df_results_v2, columns=['IS_value', 'WIS_value', 'ESS1', 'ESS2']) +df_results_v2.to_csv(out_fname) diff --git a/sepsisSim/experiments/exp-1-NaiveUW.py b/sepsisSim/experiments/exp-1-NaiveUW.py new file mode 100644 index 
def iqm(x):
    """Interquartile mean of `x`: the mean after discarding the lowest and
    highest 25% of values (robust aggregate, flattens any array shape).

    Fixed: `scipy` was never imported at module level, so calling this
    raised NameError; import it locally to keep the module's import block
    unchanged.
    """
    import scipy.stats
    return scipy.stats.trim_mean(x, proportiontocut=0.25, axis=None)
def compute_behavior_policy_weighted(df_data, trajW=None):
    """
    MLE estimate of the behavior policy π_b where each trajectory's
    transitions are counted with an external weight (row-sums of `trajW`).

    - df_data: DataFrame with columns [pt_id, Time, State, Action, Reward, ...]
    - trajW: (n_traj, m) weight matrix with one row per unique pt_id
      (sorted ascending); defaults to the identity, i.e. weight 1 per
      trajectory.

    Returns π_b of shape (nS, nA) with rows normalized to sum to 1.

    Fixed: the `trajW=None` default used `np.eye(len(df_data))` — sized by
    the number of *rows*, while the join below indexes by the number of
    unique pt_ids — so it crashed whenever a trajectory had more than one
    transition. The identity must be n_traj × n_traj. Also removed a
    leftover print-and-reraise debugging `try/except` around the count loop.
    """
    if trajW is None:
        trajW = np.eye(df_data['pt_id'].nunique())

    # Compute empirical behavior policy from data
    π_b = np.zeros((nS, nA))

    # Attach each trajectory's total weight to all of its transitions
    traj_weights = pd.DataFrame(
        trajW.sum(axis=1),
        index=sorted(df_data['pt_id'].unique()), columns=['Weight'])
    df_dataW = df_data.set_index('pt_id').join(traj_weights).reset_index()

    # Weighted state-action visit counts
    sa_counts = (df_dataW.groupby(['State', 'Action'])[['Weight']].sum()
                 .rename(columns={'Weight': 'count'}).reset_index())

    for i, row in sa_counts.iterrows():
        s, a = int(row['State']), int(row['Action'])
        count = row['count']
        if a == -1:
            π_b[s, :] = count  # terminal marker: fill the whole row
        else:
            π_b[s, a] = count

    # assume uniform action probabilities in unobserved states
    unobserved_states = (π_b.sum(axis=-1) == 0)
    π_b[unobserved_states, :] = 1

    # normalize action probabilities
    π_b = π_b / π_b.sum(axis=-1, keepdims=True)

    return π_b
the product + rho = (p_e / p_b).prod(axis=1) * trajW.sum(axis=1) + + # directly calculate weighted average over trajectories + is_value = np.nansum(G * rho) / trajW.shape[1] + wis_value = np.nansum(G * rho) / np.nansum(rho) + rho_norm = rho / np.nansum(rho) + rho_nna = rho[~np.isnan(rho)] + rho_norm_nna = rho_norm[~np.isnan(rho_norm)] + ess1 = 1 / np.nansum(rho_norm_nna ** 2) + ess1_ = (np.nansum(rho_nna)) ** 2 / (np.nansum(rho_nna ** 2)) + # assert np.isclose(ess1, ess1_) + ess2 = 1. / np.nanmax(rho_norm) + return is_value, wis_value, { + 'ESS1': ess1, 'ESS2': ess2, 'G': G, + 'rho': rho, 'rho_norm': rho_norm + } + + +# ## Policies + +# vaso unif, mv abx optimal +π_unif = (np.tile(π_star.reshape((-1,2,2,2)).sum(axis=3, keepdims=True), (1,1,1,2)).reshape((-1, 8)) / 2) + + +# ### Behavior policy + +# vaso eps=0.5, mv abx optimal +π_beh = (np.tile(π_star.reshape((-1,2,2,2)).sum(axis=3, keepdims=True), (1,1,1,2)).reshape((-1, 8)) / 2) +π_beh[π_star == 1] = 1-eps +π_beh[π_beh == 0.5] = eps + +V_H_beh, Q_H_beh, J_beh = policy_eval_helper(π_beh) +J_beh + + +# ### Optimal policy +V_H_star, Q_H_star, J_star = policy_eval_helper(π_star) +J_star + + +# ### flip action for x% states + +rng_flip = np.random.default_rng(pol_flip_seed) +flip_states = rng_flip.choice(range(1440), pol_flip_num, replace=False) + +π_tmp = (np.tile(π_star.reshape((-1,2,2,2)).sum(axis=3, keepdims=True), (1,1,1,2)).reshape((-1, 8)) / 2) +π_flip = π_tmp.copy() +π_flip[π_tmp == 0.5] = 0 +π_flip[π_star == 1] = 1 +for s in flip_states: + π_flip[s, π_tmp[s] == 0.5] = 1 + π_flip[s, π_star[s] == 1] = 0 +assert π_flip.sum(axis=1).mean() == 1 + +# np.savetxt(f'./results/{exp_name}/policy_{pol_name}.txt', π_flip) + + +# ## Compare OPE + +π_eval = π_flip + + +# ### Proposed: replace future with the value function for the evaluation policy + +df_va_all2 = pd.read_pickle(f'results/vaso_eps_{eps_str}-annotOpt_df_seed2_aug_step.pkl') +V_H_eval, Q_H_eval, J_eval = policy_eval_helper(π_eval) + + +df_results_v2 = [] +for 
run in tqdm(range(runs)): + df_va = df_va_all2.set_index('pt_id').loc[200000+run*run_idx_length:200000+run*run_idx_length + N_val - 1+0.999].reset_index() + +# # m is num of trajectories +# # (m_all, m_orig) table with binary indicator of the source (original traj) of each traj +# df_idmap = df_va[['pt_id']].copy() +# df_idmap['map_pt_id'] = df_va['pt_id'].apply(np.floor).astype(int) +# df_idmap['mask'] = 1 +# df_mapp = df_idmap.drop_duplicates().set_index(['pt_id', 'map_pt_id']).unstack().fillna(0).astype(pd.SparseDtype('int', 0)) + +# # (m_all, m_orig) traj-wise weight matrix +# # each column sums to 1 +# traj_weight_matrix = df_mapp.values +# traj_weight_matrix = (traj_weight_matrix / traj_weight_matrix.sum(axis=0)) +# assert np.isclose(traj_weight_matrix.sum(axis=0), 1).all() + + # OPE - WIS/WDR prep + v2_data_va = format_data_tensor(df_va) + v2_pi_b_val = compute_behavior_policy(df_va) + + # OPE - WIS + v2_IS_value, v2_WIS_value, v2_ESS_info = OPE_IS_trajW(v2_data_va, v2_pi_b_val, π_eval, gamma) + + df_results_v2.append([v2_IS_value, v2_WIS_value, v2_ESS_info['ESS1'], v2_ESS_info['ESS2']]) + pd.DataFrame(df_results_v2, columns=['IS_value', 'WIS_value', 'ESS1', 'ESS2']).to_csv(out_fname) + +df_results_v2 = pd.DataFrame(df_results_v2, columns=['IS_value', 'WIS_value', 'ESS1', 'ESS2']) +df_results_v2.to_csv(out_fname) diff --git a/sepsisSim/experiments/exp-1-baseline.py b/sepsisSim/experiments/exp-1-baseline.py new file mode 100644 index 0000000..b32f3a6 --- /dev/null +++ b/sepsisSim/experiments/exp-1-baseline.py @@ -0,0 +1,178 @@ +# ## Simulation parameters +exp_name = 'exp-FINAL-1' +eps = 0.10 +eps_str = '0_1' + +run_idx_length = 1_000 +N_val = 1_000 +runs = 50 + +# Number of action-flipped states +import argparse +parser = argparse.ArgumentParser() +parser.add_argument('--flip_num', type=int) +parser.add_argument('--flip_seed', type=int) +args = parser.parse_args() +pol_flip_num = args.flip_num +pol_flip_seed = args.flip_seed + +pol_name = 
f'flip{pol_flip_num}_seed{pol_flip_seed}' +out_fname = f'./results/{exp_name}/vaso_eps_{eps_str}-{pol_name}-orig.csv' + +import numpy as np +import pandas as pd + +df_tmp = None +try: + df_tmp = pd.read_csv(out_fname) +except: + pass + +if df_tmp is not None: + print('File exists') + quit() + +from tqdm import tqdm +from collections import defaultdict +import pickle +import itertools +import copy +import random +import itertools +import joblib +from joblib import Parallel, delayed + +from OPE_utils_new import ( + format_data_tensor, + policy_eval_analytic_finite, + OPE_IS_h, + compute_behavior_policy_h, +) + + +def policy_eval_helper(π): + V_H = policy_eval_analytic_finite(P.transpose((1,0,2)), R, π, gamma, H) + Q_H = [(R + gamma * P.transpose((1,0,2)) @ V_H[h]) for h in range(1,H)] + [R] + J = isd @ V_H[0] + # Check recursive relationships + assert len(Q_H) == H + assert len(V_H) == H + assert np.all(Q_H[-1] == R) + assert np.all(np.sum(π * Q_H[-1], axis=1) == V_H[-1]) + assert np.all(R + gamma * P.transpose((1,0,2)) @ V_H[-1] == Q_H[-2]) + return V_H, Q_H, J + +def iqm(x): + return scipy.stats.trim_mean(x, proportiontocut=0.25, axis=None) + +NSTEPS = H = 20 # max episode length in historical data # Horizon of the MDP +G_min = -1 # the minimum possible return +G_max = 1 # the maximum possible return +nS, nA = 1442, 8 + +PROB_DIAB = 0.2 + +# Ground truth MDP model +MDP_parameters = joblib.load('../data/MDP_parameters.joblib') +P = MDP_parameters['transition_matrix_absorbing'] # (A, S, S_next) +R = MDP_parameters['reward_matrix_absorbing_SA'] # (S, A) +nS, nA = R.shape +gamma = 0.99 + +# unif rand isd, mixture of diabetic state +isd = joblib.load('../data/modified_prior_initial_state_absorbing.joblib') +isd = (isd > 0).astype(float) +isd[:720] = isd[:720] / isd[:720].sum() * (1-PROB_DIAB) +isd[720:] = isd[720:] / isd[720:].sum() * (PROB_DIAB) + +# Precomputed optimal policy +π_star = joblib.load('../data/π_star.joblib') + + + +# ## Load data +input_dir = 
f'../datagen/vaso_eps_{eps_str}-100k/' + +def load_data(fname): + print('Loading data', fname, '...', end='') + df_data = pd.read_csv('{}/{}'.format(input_dir, fname)).rename(columns={'State_idx': 'State'})#[['pt_id', 'Time', 'State', 'Action', 'Reward']] + + # Assign next state + df_data['NextState'] = [*df_data['State'].iloc[1:].values, -1] + df_data.loc[df_data.groupby('pt_id')['Time'].idxmax(), 'NextState'] = -1 + df_data.loc[(df_data['Reward'] == -1), 'NextState'] = 1440 + df_data.loc[(df_data['Reward'] == 1), 'NextState'] = 1441 + + assert ((df_data['Reward'] != 0) == (df_data['Action'] == -1)).all() + + print('DONE') + return df_data + + +# df_seed1 = load_data('1-features.csv') # tr +df_seed2 = load_data('2-features.csv') # va + + +# ## Policies + +# vaso unif, mv abx optimal +π_unif = (np.tile(π_star.reshape((-1,2,2,2)).sum(axis=3, keepdims=True), (1,1,1,2)).reshape((-1, 8)) / 2) + + +# ### Behavior policy + +# vaso eps=0.5, mv abx optimal +π_beh = (np.tile(π_star.reshape((-1,2,2,2)).sum(axis=3, keepdims=True), (1,1,1,2)).reshape((-1, 8)) / 2) +π_beh[π_star == 1] = 1-eps +π_beh[π_beh == 0.5] = eps + +V_H_beh, Q_H_beh, J_beh = policy_eval_helper(π_beh) +J_beh + + +# ### Optimal policy +V_H_star, Q_H_star, J_star = policy_eval_helper(π_star) +J_star + + +# ### flip action for x% states + +rng_flip = np.random.default_rng(pol_flip_seed) +flip_states = rng_flip.choice(range(1440), pol_flip_num, replace=False) + +π_tmp = (np.tile(π_star.reshape((-1,2,2,2)).sum(axis=3, keepdims=True), (1,1,1,2)).reshape((-1, 8)) / 2) +π_flip = π_tmp.copy() +π_flip[π_tmp == 0.5] = 0 +π_flip[π_star == 1] = 1 +for s in flip_states: + π_flip[s, π_tmp[s] == 0.5] = 1 + π_flip[s, π_star[s] == 1] = 0 +assert π_flip.sum(axis=1).mean() == 1 + +np.savetxt(f'./results/{exp_name}/policy_{pol_name}.txt', π_flip) + + +# ## Compare OPE + +π_eval = π_flip + + +# ### Baseline IS: original dataset + +df_results = [] +for run in range(runs): + df_va = 
df_seed2.set_index('pt_id').loc[200000+run*run_idx_length:200000+run*run_idx_length + N_val - 1].reset_index()[ + ['pt_id', 'Time', 'State', 'Action', 'Reward', 'NextState'] + ] + df = df_va[['pt_id', 'Time', 'State', 'Action', 'Reward', 'NextState']] + + # OPE - WIS/WDR prep + data_va = format_data_tensor(df) + pi_b_va = compute_behavior_policy_h(df) + + # OPE - IS + IS_value, WIS_value, ESS_info = OPE_IS_h(data_va, pi_b_va, π_eval, gamma, epsilon=0.0) + df_results.append([IS_value, WIS_value, ESS_info['ESS1'], ESS_info['ESS2']]) + +df_results = pd.DataFrame(df_results, columns=['IS_value', 'WIS_value', 'ESS1', 'ESS2']) +df_results.to_csv(out_fname, index=False) diff --git a/sepsisSim/experiments/exp-1-observed.py b/sepsisSim/experiments/exp-1-observed.py new file mode 100644 index 0000000..7e238e4 --- /dev/null +++ b/sepsisSim/experiments/exp-1-observed.py @@ -0,0 +1,143 @@ +# ## Simulation parameters +exp_name = 'exp-FINAL-1' +eps = 0.10 +eps_str = '0_1' + +run_idx_length = 1_000 +N_val = 1_000 +runs = 50 + + +out_fname = f'./results/{exp_name}/vaso_eps_{eps_str}-observed.csv' + +import numpy as np +import pandas as pd +from tqdm import tqdm +from collections import defaultdict +import pickle +import itertools +import copy +import random +import matplotlib.pyplot as plt +import seaborn as sns +import scipy.stats +from sklearn import metrics +import itertools + +import joblib +from joblib import Parallel, delayed + +from OPE_utils_new import ( + format_data_tensor, + policy_eval_analytic_finite, + OPE_IS_h, + compute_behavior_policy_h, +) + + +def policy_eval_helper(π): + V_H = policy_eval_analytic_finite(P.transpose((1,0,2)), R, π, gamma, H) + Q_H = [(R + gamma * P.transpose((1,0,2)) @ V_H[h]) for h in range(1,H)] + [R] + J = isd @ V_H[0] + # Check recursive relationships + assert len(Q_H) == H + assert len(V_H) == H + assert np.all(Q_H[-1] == R) + assert np.all(np.sum(π * Q_H[-1], axis=1) == V_H[-1]) + assert np.all(R + gamma * P.transpose((1,0,2)) @ V_H[-1] == 
Q_H[-2]) + return V_H, Q_H, J + +def iqm(x): + return scipy.stats.trim_mean(x, proportiontocut=0.25, axis=None) + +NSTEPS = H = 20 # max episode length in historical data # Horizon of the MDP +G_min = -1 # the minimum possible return +G_max = 1 # the maximum possible return +nS, nA = 1442, 8 + +PROB_DIAB = 0.2 + +# Ground truth MDP model +MDP_parameters = joblib.load('../data/MDP_parameters.joblib') +P = MDP_parameters['transition_matrix_absorbing'] # (A, S, S_next) +R = MDP_parameters['reward_matrix_absorbing_SA'] # (S, A) +nS, nA = R.shape +gamma = 0.99 + +# unif rand isd, mixture of diabetic state +isd = joblib.load('../data/modified_prior_initial_state_absorbing.joblib') +isd = (isd > 0).astype(float) +isd[:720] = isd[:720] / isd[:720].sum() * (1-PROB_DIAB) +isd[720:] = isd[720:] / isd[720:].sum() * (PROB_DIAB) + +# Precomputed optimal policy +π_star = joblib.load('../data/π_star.joblib') + + + +# ## Load data +input_dir = f'../datagen/vaso_eps_{eps_str}-100k/' + +def load_data(fname): + print('Loading data', fname, '...', end='') + df_data = pd.read_csv('{}/{}'.format(input_dir, fname)).rename(columns={'State_idx': 'State'})#[['pt_id', 'Time', 'State', 'Action', 'Reward']] + + # Assign next state + df_data['NextState'] = [*df_data['State'].iloc[1:].values, -1] + df_data.loc[df_data.groupby('pt_id')['Time'].idxmax(), 'NextState'] = -1 + df_data.loc[(df_data['Reward'] == -1), 'NextState'] = 1440 + df_data.loc[(df_data['Reward'] == 1), 'NextState'] = 1441 + + assert ((df_data['Reward'] != 0) == (df_data['Action'] == -1)).all() + + print('DONE') + return df_data + +# df_seed1 = load_data('1-features.csv') # tr +df_seed2 = load_data('2-features.csv') # va + + +# ## Policies + +# vaso unif, mv abx optimal +π_unif = (np.tile(π_star.reshape((-1,2,2,2)).sum(axis=3, keepdims=True), (1,1,1,2)).reshape((-1, 8)) / 2) + + +# ### Behavior policy + +# vaso eps=0.5, mv abx optimal +π_beh = (np.tile(π_star.reshape((-1,2,2,2)).sum(axis=3, keepdims=True), (1,1,1,2)).reshape((-1, 
8)) / 2) +π_beh[π_star == 1] = 1-eps +π_beh[π_beh == 0.5] = eps + +V_H_beh, Q_H_beh, J_beh = policy_eval_helper(π_beh) +J_beh + + +# ### Optimal policy +V_H_star, Q_H_star, J_star = policy_eval_helper(π_star) +J_star + + + +# ## Compare OPE + +# ### observed behavior + +df_results_0 = [] +for run in range(runs): + df_va = df_seed2.set_index('pt_id').loc[200000+run*run_idx_length:200000+run*run_idx_length + N_val - 1].reset_index()[ + ['pt_id', 'Time', 'State', 'Action', 'Reward', 'NextState'] + ] + df = df_va[['pt_id', 'Time', 'State', 'Action', 'Reward', 'NextState']] + + # OPE - WIS/WDR prep + data_va = format_data_tensor(df) + pi_b_va = compute_behavior_policy_h(df) + + # OPE - IS + IS_value, WIS_value, info = OPE_IS_h(data_va, pi_b_va, π_star, gamma, epsilon=0.0) + df_results_0.append([info['G'].mean()]) + +df_results_0 = pd.DataFrame(df_results_0, columns=['IS_value']) +df_results_0.to_csv(out_fname, index=False) diff --git a/sepsisSim/experiments/exp-1-onpolicy-baseline.py b/sepsisSim/experiments/exp-1-onpolicy-baseline.py new file mode 100644 index 0000000..efcd090 --- /dev/null +++ b/sepsisSim/experiments/exp-1-onpolicy-baseline.py @@ -0,0 +1,145 @@ +# ## Simulation parameters +exp_name = 'exp-FINAL-1' +eps = 0.10 +eps_str = '0_1' + +run_idx_length = 1_000 +N_val = 1_000 +runs = 50 + +pol_name = 'onpolicy' +out_fname = f'./results/{exp_name}/vaso_eps_{eps_str}-{pol_name}-orig.csv' + +import numpy as np +import pandas as pd + +df_tmp = None +try: + df_tmp = pd.read_csv(out_fname) +except: + pass + +if df_tmp is not None: + print('File exists') + quit() + +from tqdm import tqdm +from collections import defaultdict +import pickle +import itertools +import copy +import random +import itertools +import joblib +from joblib import Parallel, delayed + +from OPE_utils_new import ( + format_data_tensor, + policy_eval_analytic_finite, + OPE_IS_h, + compute_behavior_policy_h, +) + + +def policy_eval_helper(π): + V_H = policy_eval_analytic_finite(P.transpose((1,0,2)), 
R, π, gamma, H) + Q_H = [(R + gamma * P.transpose((1,0,2)) @ V_H[h]) for h in range(1,H)] + [R] + J = isd @ V_H[0] + # Check recursive relationships + assert len(Q_H) == H + assert len(V_H) == H + assert np.all(Q_H[-1] == R) + assert np.all(np.sum(π * Q_H[-1], axis=1) == V_H[-1]) + assert np.all(R + gamma * P.transpose((1,0,2)) @ V_H[-1] == Q_H[-2]) + return V_H, Q_H, J + +def iqm(x): + return scipy.stats.trim_mean(x, proportiontocut=0.25, axis=None) + +NSTEPS = H = 20 # max episode length in historical data # Horizon of the MDP +G_min = -1 # the minimum possible return +G_max = 1 # the maximum possible return +nS, nA = 1442, 8 + +PROB_DIAB = 0.2 + +# Ground truth MDP model +MDP_parameters = joblib.load('../data/MDP_parameters.joblib') +P = MDP_parameters['transition_matrix_absorbing'] # (A, S, S_next) +R = MDP_parameters['reward_matrix_absorbing_SA'] # (S, A) +nS, nA = R.shape +gamma = 0.99 + +# unif rand isd, mixture of diabetic state +isd = joblib.load('../data/modified_prior_initial_state_absorbing.joblib') +isd = (isd > 0).astype(float) +isd[:720] = isd[:720] / isd[:720].sum() * (1-PROB_DIAB) +isd[720:] = isd[720:] / isd[720:].sum() * (PROB_DIAB) + +# Precomputed optimal policy +π_star = joblib.load('../data/π_star.joblib') + + + +# ## Load data +input_dir = f'../datagen/vaso_eps_{eps_str}-100k/' + +def load_data(fname): + print('Loading data', fname, '...', end='') + df_data = pd.read_csv('{}/{}'.format(input_dir, fname)).rename(columns={'State_idx': 'State'})#[['pt_id', 'Time', 'State', 'Action', 'Reward']] + + # Assign next state + df_data['NextState'] = [*df_data['State'].iloc[1:].values, -1] + df_data.loc[df_data.groupby('pt_id')['Time'].idxmax(), 'NextState'] = -1 + df_data.loc[(df_data['Reward'] == -1), 'NextState'] = 1440 + df_data.loc[(df_data['Reward'] == 1), 'NextState'] = 1441 + + assert ((df_data['Reward'] != 0) == (df_data['Action'] == -1)).all() + + print('DONE') + return df_data + + +# df_seed1 = load_data('1-features.csv') # tr +df_seed2 = 
load_data('2-features.csv') # va + + +# ## Policies + +# vaso unif, mv abx optimal +π_unif = (np.tile(π_star.reshape((-1,2,2,2)).sum(axis=3, keepdims=True), (1,1,1,2)).reshape((-1, 8)) / 2) + + +# ### Behavior policy + +# vaso eps=0.5, mv abx optimal +π_beh = (np.tile(π_star.reshape((-1,2,2,2)).sum(axis=3, keepdims=True), (1,1,1,2)).reshape((-1, 8)) / 2) +π_beh[π_star == 1] = 1-eps +π_beh[π_beh == 0.5] = eps + +V_H_beh, Q_H_beh, J_beh = policy_eval_helper(π_beh) + +# ## Compare OPE + +π_eval = π_beh + + +# ### Baseline IS: original dataset + +df_results = [] +for run in range(runs): + df_va = df_seed2.set_index('pt_id').loc[200000+run*run_idx_length:200000+run*run_idx_length + N_val - 1].reset_index()[ + ['pt_id', 'Time', 'State', 'Action', 'Reward', 'NextState'] + ] + df = df_va[['pt_id', 'Time', 'State', 'Action', 'Reward', 'NextState']] + + # OPE - WIS/WDR prep + data_va = format_data_tensor(df) + pi_b_va = compute_behavior_policy_h(df) + + # OPE - IS + IS_value, WIS_value, ESS_info = OPE_IS_h(data_va, pi_b_va, π_eval, gamma, epsilon=0.0) + df_results.append([IS_value, WIS_value, ESS_info['ESS1'], ESS_info['ESS2']]) + +df_results = pd.DataFrame(df_results, columns=['IS_value', 'WIS_value', 'ESS1', 'ESS2']) +df_results.to_csv(out_fname, index=False) diff --git a/sepsisSim/experiments/exp-2-annotBeh.py b/sepsisSim/experiments/exp-2-annotBeh.py new file mode 100644 index 0000000..cbed2fa --- /dev/null +++ b/sepsisSim/experiments/exp-2-annotBeh.py @@ -0,0 +1,374 @@ +# ## Simulation parameters +exp_name = 'exp-FINAL-2' +eps = 0.10 +eps_str = '0_1' + +run_idx_length = 1_000 +N_val = 1_000 +runs = 50 + +# Number of action-flipped states +import argparse +parser = argparse.ArgumentParser() +parser.add_argument('--flip_num', type=int) +parser.add_argument('--flip_seed', type=int) +args = parser.parse_args() +pol_flip_num = args.flip_num +pol_flip_seed = args.flip_seed + +pol_name = f'flip{pol_flip_num}_seed{pol_flip_seed}' +out_fname = 
f'./results/{exp_name}/vaso_eps_{eps_str}-{pol_name}-aug_step-annotBeh.csv' + +import numpy as np +import pandas as pd + +df_tmp = None +try: + df_tmp = pd.read_csv(out_fname) +except: + pass + +if df_tmp is not None: + print('File exists') + quit() + +from tqdm import tqdm +from collections import defaultdict +import pickle +import itertools +import copy +import random +import itertools +import joblib +from joblib import Parallel, delayed + +from OPE_utils_new import ( + format_data_tensor, + policy_eval_analytic_finite, + OPE_IS_h, + compute_behavior_policy_h, +) + + +def policy_eval_helper(π): + V_H = policy_eval_analytic_finite(P.transpose((1,0,2)), R, π, gamma, H) + Q_H = [(R + gamma * P.transpose((1,0,2)) @ V_H[h]) for h in range(1,H)] + [R] + J = isd @ V_H[0] + # Check recursive relationships + assert len(Q_H) == H + assert len(V_H) == H + assert np.all(Q_H[-1] == R) + assert np.all(np.sum(π * Q_H[-1], axis=1) == V_H[-1]) + assert np.all(R + gamma * P.transpose((1,0,2)) @ V_H[-1] == Q_H[-2]) + return V_H, Q_H, J + +def iqm(x): + return scipy.stats.trim_mean(x, proportiontocut=0.25, axis=None) + +NSTEPS = H = 20 # max episode length in historical data # Horizon of the MDP +G_min = -1 # the minimum possible return +G_max = 1 # the maximum possible return +nS, nA = 1442, 8 + +PROB_DIAB = 0.2 + +# Ground truth MDP model +MDP_parameters = joblib.load('../data/MDP_parameters.joblib') +P = MDP_parameters['transition_matrix_absorbing'] # (A, S, S_next) +R = MDP_parameters['reward_matrix_absorbing_SA'] # (S, A) +nS, nA = R.shape +gamma = 0.99 + +# unif rand isd, mixture of diabetic state +isd = joblib.load('../data/modified_prior_initial_state_absorbing.joblib') +isd = (isd > 0).astype(float) +isd[:720] = isd[:720] / isd[:720].sum() * (1-PROB_DIAB) +isd[720:] = isd[720:] / isd[720:].sum() * (PROB_DIAB) + +# Precomputed optimal policy +π_star = joblib.load('../data/π_star.joblib') + + + +# ## Load data +input_dir = f'../datagen/vaso_eps_{eps_str}-100k/' + +def 
load_data(fname): + print('Loading data', fname, '...', end='') + df_data = pd.read_csv('{}/{}'.format(input_dir, fname)).rename(columns={'State_idx': 'State'})#[['pt_id', 'Time', 'State', 'Action', 'Reward']] + + # Assign next state + df_data['NextState'] = [*df_data['State'].iloc[1:].values, -1] + df_data.loc[df_data.groupby('pt_id')['Time'].idxmax(), 'NextState'] = -1 + df_data.loc[(df_data['Reward'] == -1), 'NextState'] = 1440 + df_data.loc[(df_data['Reward'] == 1), 'NextState'] = 1441 + + assert ((df_data['Reward'] != 0) == (df_data['Action'] == -1)).all() + + print('DONE') + return df_data + + +# df_seed1 = load_data('1-features.csv') # tr +df_seed2 = load_data('2-features.csv') # va + + +# ## C-PDIS code + +def compute_augmented_behavior_policy_h(df_data): + πh_b = np.zeros((H, nS, nA)) + hsa_counts = df_data.groupby(['Time', 'State', 'Action'])[['Weight']].sum().rename(columns={'Weight': 'count'}).reset_index() + + try: + for i, row in hsa_counts.iterrows(): + h, s, a = int(row['Time']), int(row['State']), int(row['Action']) + count = row['count'] + if row['Action'] == -1: + πh_b[h, s, :] = count + else: + πh_b[h, s, a] = count + except: + print(h,s,a) + raise + # import pdb + # pdb.set_trace() + + # assume uniform action probabilities in unobserved states + unobserved_states = (πh_b.sum(axis=-1) == 0) + πh_b[unobserved_states, :] = 1 + + # normalize action probabilities + πh_b = πh_b / πh_b.sum(axis=-1, keepdims=True) + + return πh_b + + +def format_data_tensor_cf(df_data, id_col='map_pt_id'): + """ + Converts data from a dataframe to a tensor + - df_data: pd.DataFrame with columns [id_col, Time, State, Action, Reward, NextState] + - id_col specifies the index column to group episodes + - data_tensor: integer tensor of shape (N, NSTEPS, 5) with the last last dimension being [t, s, a, r, s'] + """ + data_dict = dict(list(df_data.groupby(id_col))) + N = len(data_dict) + data_tensor = np.zeros((N, 2*NSTEPS, 6), dtype=float) + data_tensor[:, :, 0] = -1 # 
initialize all time steps to -1 + data_tensor[:, :, 2] = -1 # initialize all actions to -1 + data_tensor[:, :, 1] = -1 # initialize all states to -1 + data_tensor[:, :, 4] = -1 # initialize all next states to -1 + data_tensor[:, :, 5] = np.nan # initialize all weights to NaN + + for i, (pt_id, df_values) in tqdm(enumerate(data_dict.items()), disable=True): + values = df_values.set_index(id_col)[['Time', 'State', 'Action', 'Reward', 'NextState', 'Weight']].values + data_tensor[i, :len(values), :] = values + return data_tensor + + +def OPE_PDIS_h(data, π_b, π_e, γ, epsilon=0.01): + """ + - π_b, π_e: behavior/evaluation policy, shape (S,A) + """ + # Get a soft version of the evaluation policy for WIS + π_e_soft = π_e.astype(float) * (1 - epsilon*2) + π_unif * epsilon*2 + + # # Get a soft version of the behavior policy for WIS + # π_b_soft = π_b * (1 - epsilon) + epsilon / nA + + # Apply WIS + return _pdis_h(data, π_b, π_e_soft, γ) + +def _pdis_h(data, π_b, π_e, γ): + # For each original trajectory + v_all, rho_all = [], [] + for i, data_i in enumerate(data): + # Get all trajectories based on this trajectory + t_l = data_i[..., 0].astype(int) + s_l = data_i[..., 1].astype(int) + a_l = data_i[..., 2].astype(int) + r_l = data_i[..., 3].astype(float) + snext_l = data_i[..., 4].astype(int) + w_l = data_i[..., 5].astype(float) + + # Per-transition importance ratios + p_b = π_b[t_l, s_l, a_l] + p_e = π_e[s_l, a_l] + + # Deal with variable length sequences by setting ratio to 1 + terminated_idx = (a_l == -1) + terminating_idx = (s_l != -1) & (a_l == -1) + p_b[terminated_idx] = np.nan + p_e[terminated_idx] = np.nan + p_b[terminating_idx] = 1 + p_e[terminating_idx] = 1 + + # Per-step cumulative importance ratios + rho_t = (p_e / p_b) + + # # Last observed step of each trajectory + # idx_last = np.array([np.max(np.nonzero(s_l[row] != -1)) for row in range(len(s_l))]) + + # Initialize value to 0, importance ratio to 1 + v = 0 + rho_cum = 1 + + # Iterate backwards from step H to 1 
+ for h in reversed(range(H)): + # only start computing from the last observed step + if not (t_l == h).any(): + continue + + # do we have counterfactual annotation for this step? + if (t_l == h).sum() > 1: + # if we have counterfactual annotations for this step + j_all = np.argwhere(t_l == h).ravel() + assert np.isclose(w_l[j_all].sum(), 1) # weights add up to 1 + + # Identify factual transition and counterfactual annotations + f_, cf_ = [], [] + for j in j_all: + if snext_l[j] == 1442: # counterfactual annotation have dummy next state + cf_.append(j) + else: + f_.append(j) + assert len(f_) == 1 # there should only be one factual transition + f_ = f_[0] + v = w_l[f_]*rho_t[f_]*(r_l[f_]+γ*v) + np.sum([w_l[j]*rho_t[j]*r_l[j] for j in cf_]) + rho_cum = rho_cum * (w_l[f_]*rho_t[f_]) + np.sum([w_l[j]*rho_t[j] for j in cf_]) + else: + # we don't have counterfactual annotations for this step + # there should only be one trajectory and that must be the original traj + j = (t_l == h).argmax() + assert ~np.isnan(p_e[j]) + assert w_l[j] == 1.0 + v = rho_t[j] * (r_l[j]+γ*v) + rho_cum = rho_cum * rho_t[j] + + v_all.append(v) + rho_all.append(rho_cum) + + v_all = np.array(v_all) + rho_all = np.array(rho_all) + is_value = np.nansum(v_all) / len(rho_all) + wis_value = np.nansum(v_all) / np.nansum(rho_all) + rho_norm = rho_all / np.nansum(rho_all) + rho_nna = rho_all[~np.isnan(rho_all)] + rho_norm_nna = rho_norm[~np.isnan(rho_norm)] + ess1 = 1 / np.nansum(rho_norm_nna ** 2) + ess1_ = (np.nansum(rho_nna)) ** 2 / (np.nansum(rho_nna ** 2)) + ess2 = 1. 
/ np.nanmax(rho_norm) + return is_value, wis_value, { + 'ESS1': ess1, 'ESS2': ess2, + 'rho': rho_all, 'rho_norm': rho_norm_nna, + } + + +## Default weighting scheme for C-PDIS + +weight_a_sa = np.zeros((nS, nA, nA)) + +# default weight if no counterfactual actions +for a in range(nA): + weight_a_sa[:, a, a] = 1 + +# split equally between factual and counterfactual actions +for s in range(nS): + a = π_star.argmax(axis=1)[s] + a_tilde = a+1-2*(a%2) + weight_a_sa[s, a, a] = 0.5 + weight_a_sa[s, a, a_tilde] = 0.5 + weight_a_sa[s, a_tilde, a] = 0.5 + weight_a_sa[s, a_tilde, a_tilde] = 0.5 + +assert np.all(weight_a_sa.sum(axis=-1) == 1) + + +# ## Policies + +# vaso unif, mv abx optimal +π_unif = (np.tile(π_star.reshape((-1,2,2,2)).sum(axis=3, keepdims=True), (1,1,1,2)).reshape((-1, 8)) / 2) + + +# ### Behavior policy + +# vaso eps=0.5, mv abx optimal +π_beh = (np.tile(π_star.reshape((-1,2,2,2)).sum(axis=3, keepdims=True), (1,1,1,2)).reshape((-1, 8)) / 2) +π_beh[π_star == 1] = 1-eps +π_beh[π_beh == 0.5] = eps + +V_H_beh, Q_H_beh, J_beh = policy_eval_helper(π_beh) +J_beh + + +# ### Optimal policy +V_H_star, Q_H_star, J_star = policy_eval_helper(π_star) +J_star + + +# ### flip action for x% states + +rng_flip = np.random.default_rng(pol_flip_seed) +flip_states = rng_flip.choice(range(1440), pol_flip_num, replace=False) + +π_tmp = (np.tile(π_star.reshape((-1,2,2,2)).sum(axis=3, keepdims=True), (1,1,1,2)).reshape((-1, 8)) / 2) +π_flip = π_tmp.copy() +π_flip[π_tmp == 0.5] = 0 +π_flip[π_star == 1] = 1 +for s in flip_states: + π_flip[s, π_tmp[s] == 0.5] = 1 + π_flip[s, π_star[s] == 1] = 0 +assert π_flip.sum(axis=1).mean() == 1 + +# np.savetxt(f'./results/{exp_name}/policy_{pol_name}.txt', π_flip) + + +# ## Compare OPE + +π_eval = π_flip + + +# ### Proposed: replace future with the value function for the evaluation policy + +df_va_all2 = pd.read_pickle(f'results/vaso_eps_{eps_str}-annotOpt_df_seed2_aug_step.pkl') +V_H_eval, Q_H_eval, J_eval = policy_eval_helper(π_eval) + + 
+df_results_v2 = []
+for run in range(runs):
+    df_va = df_va_all2.set_index('pt_id').loc[200000+run*run_idx_length:200000+run*run_idx_length + N_val - 1+0.999].reset_index()
+
+    df_va['map_pt_id'] = df_va['pt_id'].apply(np.floor).astype(int)
+    df_va = df_va.drop_duplicates(['map_pt_id', 'Time', 'State', 'Action']) \
+        .sort_values(by=['map_pt_id', 'Time', 'pt_id']).reset_index(drop=True)
+    df_va['Weight'] = np.nan
+
+    for i, row in tqdm(list(df_va.iterrows()), disable=True):
+        map_pt_id = row['map_pt_id']
+        h = int(row['Time'])
+        s = int(row['State'])
+        if row['NextState'] in [1440, 1441]:
+            df_va.loc[i, 'Weight'] = 1.0
+        elif row['NextState'] in [1442]:
+            a_cf = int(row['Action'])
+            a_f = int(a_cf+1-2*(a_cf%2))
+            df_va.loc[i, 'Weight'] = weight_a_sa[s, a_f, a_cf]
+            df_va.loc[(df_va['map_pt_id'] == map_pt_id)
+                      & (df_va['Time'] == h)
+                      & (df_va['Action'] == a_f), 'Weight'] = weight_a_sa[s, a_f, a_f]
+            df_va.loc[i, 'Reward'] = Q_H_beh[h][s, a_cf]
+        else:
+            pass
+    df_va['Weight'] = df_va['Weight'].fillna(1)
+
+    # OPE - WIS/WDR prep
+    v2_data_va = format_data_tensor_cf(df_va)
+    v2_pi_b_val = compute_augmented_behavior_policy_h(df_va)
+
+    # OPE - WIS
+    v2_IS_value, v2_WIS_value, v2_ESS_info = OPE_PDIS_h(v2_data_va, v2_pi_b_val, π_eval, gamma, epsilon=0.0)
+
+    df_results_v2.append([v2_IS_value, v2_WIS_value, v2_ESS_info['ESS1']])
+    pd.DataFrame(df_results_v2, columns=['IS_value', 'WIS_value', 'ESS1']).to_csv(out_fname, index=False)
+
+df_results_v2 = pd.DataFrame(df_results_v2, columns=['IS_value', 'WIS_value', 'ESS1'])
+df_results_v2.to_csv(out_fname, index=False)
diff --git a/sepsisSim/experiments/exp-2-annotEval.py b/sepsisSim/experiments/exp-2-annotEval.py
new file mode 100644
index 0000000..f5b30df
--- /dev/null
+++ b/sepsisSim/experiments/exp-2-annotEval.py
@@ -0,0 +1,374 @@
+# ## Simulation parameters
+exp_name = 'exp-FINAL-2'
+eps = 0.10
+eps_str = '0_1'
+
+run_idx_length = 1_000
+N_val = 1_000
+runs = 50
+
+# Number of action-flipped states
+import argparse
+parser = argparse.ArgumentParser()
+parser.add_argument('--flip_num', type=int)
+parser.add_argument('--flip_seed', type=int)
+args = parser.parse_args()
+pol_flip_num = args.flip_num
+pol_flip_seed = args.flip_seed
+
+pol_name = f'flip{pol_flip_num}_seed{pol_flip_seed}'
+out_fname = f'./results/{exp_name}/vaso_eps_{eps_str}-{pol_name}-aug_step-annotEval.csv'
+
+import numpy as np
+import pandas as pd
+
+df_tmp = None
+try:
+    df_tmp = pd.read_csv(out_fname)
+except:
+    pass
+
+if df_tmp is not None:
+    print('File exists')
+    quit()
+
+from tqdm import tqdm
+from collections import defaultdict
+import pickle
+import itertools
+import copy
+import random
+import itertools
+import joblib
+from joblib import Parallel, delayed
+
+from OPE_utils_new import (
+    format_data_tensor,
+    policy_eval_analytic_finite,
+    OPE_IS_h,
+    compute_behavior_policy_h,
+)
+
+
+def policy_eval_helper(π):
+    V_H = policy_eval_analytic_finite(P.transpose((1,0,2)), R, π, gamma, H)
+    Q_H = [(R + gamma * P.transpose((1,0,2)) @ V_H[h]) for h in range(1,H)] + [R]
+    J = isd @ V_H[0]
+    # Check recursive relationships
+    assert len(Q_H) == H
+    assert len(V_H) == H
+    assert np.all(Q_H[-1] == R)
+    assert np.all(np.sum(π * Q_H[-1], axis=1) == V_H[-1])
+    assert np.all(R + gamma * P.transpose((1,0,2)) @ V_H[-1] == Q_H[-2])
+    return V_H, Q_H, J
+
+def iqm(x):
+    return scipy.stats.trim_mean(x, proportiontocut=0.25, axis=None)
+
+NSTEPS = H = 20 # max episode length in historical data # Horizon of the MDP
+G_min = -1 # the minimum possible return
+G_max = 1 # the maximum possible return
+nS, nA = 1442, 8
+
+PROB_DIAB = 0.2
+
+# Ground truth MDP model
+MDP_parameters = joblib.load('../data/MDP_parameters.joblib')
+P = MDP_parameters['transition_matrix_absorbing'] # (A, S, S_next)
+R = MDP_parameters['reward_matrix_absorbing_SA'] # (S, A)
+nS, nA = R.shape
+gamma = 0.99
+
+# unif rand isd, mixture of diabetic state
+isd = joblib.load('../data/modified_prior_initial_state_absorbing.joblib')
+isd = (isd > 
0).astype(float) +isd[:720] = isd[:720] / isd[:720].sum() * (1-PROB_DIAB) +isd[720:] = isd[720:] / isd[720:].sum() * (PROB_DIAB) + +# Precomputed optimal policy +π_star = joblib.load('../data/π_star.joblib') + + + +# ## Load data +input_dir = f'../datagen/vaso_eps_{eps_str}-100k/' + +def load_data(fname): + print('Loading data', fname, '...', end='') + df_data = pd.read_csv('{}/{}'.format(input_dir, fname)).rename(columns={'State_idx': 'State'})#[['pt_id', 'Time', 'State', 'Action', 'Reward']] + + # Assign next state + df_data['NextState'] = [*df_data['State'].iloc[1:].values, -1] + df_data.loc[df_data.groupby('pt_id')['Time'].idxmax(), 'NextState'] = -1 + df_data.loc[(df_data['Reward'] == -1), 'NextState'] = 1440 + df_data.loc[(df_data['Reward'] == 1), 'NextState'] = 1441 + + assert ((df_data['Reward'] != 0) == (df_data['Action'] == -1)).all() + + print('DONE') + return df_data + + +# df_seed1 = load_data('1-features.csv') # tr +df_seed2 = load_data('2-features.csv') # va + + +# ## C-PDIS code + +def compute_augmented_behavior_policy_h(df_data): + πh_b = np.zeros((H, nS, nA)) + hsa_counts = df_data.groupby(['Time', 'State', 'Action'])[['Weight']].sum().rename(columns={'Weight': 'count'}).reset_index() + + try: + for i, row in hsa_counts.iterrows(): + h, s, a = int(row['Time']), int(row['State']), int(row['Action']) + count = row['count'] + if row['Action'] == -1: + πh_b[h, s, :] = count + else: + πh_b[h, s, a] = count + except: + print(h,s,a) + raise + # import pdb + # pdb.set_trace() + + # assume uniform action probabilities in unobserved states + unobserved_states = (πh_b.sum(axis=-1) == 0) + πh_b[unobserved_states, :] = 1 + + # normalize action probabilities + πh_b = πh_b / πh_b.sum(axis=-1, keepdims=True) + + return πh_b + + +def format_data_tensor_cf(df_data, id_col='map_pt_id'): + """ + Converts data from a dataframe to a tensor + - df_data: pd.DataFrame with columns [id_col, Time, State, Action, Reward, NextState] + - id_col specifies the index column to 
group episodes + - data_tensor: integer tensor of shape (N, NSTEPS, 5) with the last last dimension being [t, s, a, r, s'] + """ + data_dict = dict(list(df_data.groupby(id_col))) + N = len(data_dict) + data_tensor = np.zeros((N, 2*NSTEPS, 6), dtype=float) + data_tensor[:, :, 0] = -1 # initialize all time steps to -1 + data_tensor[:, :, 2] = -1 # initialize all actions to -1 + data_tensor[:, :, 1] = -1 # initialize all states to -1 + data_tensor[:, :, 4] = -1 # initialize all next states to -1 + data_tensor[:, :, 5] = np.nan # initialize all weights to NaN + + for i, (pt_id, df_values) in tqdm(enumerate(data_dict.items()), disable=True): + values = df_values.set_index(id_col)[['Time', 'State', 'Action', 'Reward', 'NextState', 'Weight']].values + data_tensor[i, :len(values), :] = values + return data_tensor + + +def OPE_PDIS_h(data, π_b, π_e, γ, epsilon=0.01): + """ + - π_b, π_e: behavior/evaluation policy, shape (S,A) + """ + # Get a soft version of the evaluation policy for WIS + π_e_soft = π_e.astype(float) * (1 - epsilon*2) + π_unif * epsilon*2 + + # # Get a soft version of the behavior policy for WIS + # π_b_soft = π_b * (1 - epsilon) + epsilon / nA + + # Apply WIS + return _pdis_h(data, π_b, π_e_soft, γ) + +def _pdis_h(data, π_b, π_e, γ): + # For each original trajectory + v_all, rho_all = [], [] + for i, data_i in enumerate(data): + # Get all trajectories based on this trajectory + t_l = data_i[..., 0].astype(int) + s_l = data_i[..., 1].astype(int) + a_l = data_i[..., 2].astype(int) + r_l = data_i[..., 3].astype(float) + snext_l = data_i[..., 4].astype(int) + w_l = data_i[..., 5].astype(float) + + # Per-transition importance ratios + p_b = π_b[t_l, s_l, a_l] + p_e = π_e[s_l, a_l] + + # Deal with variable length sequences by setting ratio to 1 + terminated_idx = (a_l == -1) + terminating_idx = (s_l != -1) & (a_l == -1) + p_b[terminated_idx] = np.nan + p_e[terminated_idx] = np.nan + p_b[terminating_idx] = 1 + p_e[terminating_idx] = 1 + + # Per-step cumulative 
importance ratios + rho_t = (p_e / p_b) + + # # Last observed step of each trajectory + # idx_last = np.array([np.max(np.nonzero(s_l[row] != -1)) for row in range(len(s_l))]) + + # Initialize value to 0, importance ratio to 1 + v = 0 + rho_cum = 1 + + # Iterate backwards from step H to 1 + for h in reversed(range(H)): + # only start computing from the last observed step + if not (t_l == h).any(): + continue + + # do we have counterfactual annotation for this step? + if (t_l == h).sum() > 1: + # if we have counterfactual annotations for this step + j_all = np.argwhere(t_l == h).ravel() + assert np.isclose(w_l[j_all].sum(), 1) # weights add up to 1 + + # Identify factual transition and counterfactual annotations + f_, cf_ = [], [] + for j in j_all: + if snext_l[j] == 1442: # counterfactual annotation have dummy next state + cf_.append(j) + else: + f_.append(j) + assert len(f_) == 1 # there should only be one factual transition + f_ = f_[0] + v = w_l[f_]*rho_t[f_]*(r_l[f_]+γ*v) + np.sum([w_l[j]*rho_t[j]*r_l[j] for j in cf_]) + rho_cum = rho_cum * (w_l[f_]*rho_t[f_]) + np.sum([w_l[j]*rho_t[j] for j in cf_]) + else: + # we don't have counterfactual annotations for this step + # there should only be one trajectory and that must be the original traj + j = (t_l == h).argmax() + assert ~np.isnan(p_e[j]) + assert w_l[j] == 1.0 + v = rho_t[j] * (r_l[j]+γ*v) + rho_cum = rho_cum * rho_t[j] + + v_all.append(v) + rho_all.append(rho_cum) + + v_all = np.array(v_all) + rho_all = np.array(rho_all) + is_value = np.nansum(v_all) / len(rho_all) + wis_value = np.nansum(v_all) / np.nansum(rho_all) + rho_norm = rho_all / np.nansum(rho_all) + rho_nna = rho_all[~np.isnan(rho_all)] + rho_norm_nna = rho_norm[~np.isnan(rho_norm)] + ess1 = 1 / np.nansum(rho_norm_nna ** 2) + ess1_ = (np.nansum(rho_nna)) ** 2 / (np.nansum(rho_nna ** 2)) + ess2 = 1. 
/ np.nanmax(rho_norm) + return is_value, wis_value, { + 'ESS1': ess1, 'ESS2': ess2, + 'rho': rho_all, 'rho_norm': rho_norm_nna, + } + + +## Default weighting scheme for C-PDIS + +weight_a_sa = np.zeros((nS, nA, nA)) + +# default weight if no counterfactual actions +for a in range(nA): + weight_a_sa[:, a, a] = 1 + +# split equally between factual and counterfactual actions +for s in range(nS): + a = π_star.argmax(axis=1)[s] + a_tilde = a+1-2*(a%2) + weight_a_sa[s, a, a] = 0.5 + weight_a_sa[s, a, a_tilde] = 0.5 + weight_a_sa[s, a_tilde, a] = 0.5 + weight_a_sa[s, a_tilde, a_tilde] = 0.5 + +assert np.all(weight_a_sa.sum(axis=-1) == 1) + + +# ## Policies + +# vaso unif, mv abx optimal +π_unif = (np.tile(π_star.reshape((-1,2,2,2)).sum(axis=3, keepdims=True), (1,1,1,2)).reshape((-1, 8)) / 2) + + +# ### Behavior policy + +# vaso eps=0.5, mv abx optimal +π_beh = (np.tile(π_star.reshape((-1,2,2,2)).sum(axis=3, keepdims=True), (1,1,1,2)).reshape((-1, 8)) / 2) +π_beh[π_star == 1] = 1-eps +π_beh[π_beh == 0.5] = eps + +V_H_beh, Q_H_beh, J_beh = policy_eval_helper(π_beh) +J_beh + + +# ### Optimal policy +V_H_star, Q_H_star, J_star = policy_eval_helper(π_star) +J_star + + +# ### flip action for x% states + +rng_flip = np.random.default_rng(pol_flip_seed) +flip_states = rng_flip.choice(range(1440), pol_flip_num, replace=False) + +π_tmp = (np.tile(π_star.reshape((-1,2,2,2)).sum(axis=3, keepdims=True), (1,1,1,2)).reshape((-1, 8)) / 2) +π_flip = π_tmp.copy() +π_flip[π_tmp == 0.5] = 0 +π_flip[π_star == 1] = 1 +for s in flip_states: + π_flip[s, π_tmp[s] == 0.5] = 1 + π_flip[s, π_star[s] == 1] = 0 +assert π_flip.sum(axis=1).mean() == 1 + +# np.savetxt(f'./results/{exp_name}/policy_{pol_name}.txt', π_flip) + + +# ## Compare OPE + +π_eval = π_flip + + +# ### Proposed: replace future with the value function for the evaluation policy + +df_va_all2 = pd.read_pickle(f'results/vaso_eps_{eps_str}-annotOpt_df_seed2_aug_step.pkl') +V_H_eval, Q_H_eval, J_eval = policy_eval_helper(π_eval) + + 
+df_results_v2 = [] +for run in range(runs): + df_va = df_va_all2.set_index('pt_id').loc[200000+run*run_idx_length:200000+run*run_idx_length + N_val - 1+0.999].reset_index() + + df_va['map_pt_id'] = df_va['pt_id'].apply(np.floor).astype(int) + df_va = df_va.drop_duplicates(['map_pt_id', 'Time', 'State', 'Action']) \ + .sort_values(by=['map_pt_id', 'Time', 'pt_id']).reset_index(drop=True) + df_va['Weight'] = np.nan + + for i, row in tqdm(list(df_va.iterrows()), disable=True): + map_pt_id = row['map_pt_id'] + h = int(row['Time']) + s = int(row['State']) + if row['NextState'] in [1440, 1441]: + df_va.loc[i, 'Weight'] = 1.0 + elif row['NextState'] in [1442]: + a_cf = int(row['Action']) + a_f = int(a_cf+1-2*(a_cf%2)) + df_va.loc[i, 'Weight'] = weight_a_sa[s, a_f, a_cf] + df_va.loc[(df_va['map_pt_id'] == map_pt_id) + & (df_va['Time'] == h) + & (df_va['Action'] == a_f), 'Weight'] = weight_a_sa[s, a_f, a_f] + df_va.loc[i, 'Reward'] = Q_H_eval[h][s, a_cf] + else: + pass + df_va['Weight'] = df_va['Weight'].fillna(1) + + # OPE - WIS/WDR prep + v2_data_va = format_data_tensor_cf(df_va) + v2_pi_b_val = compute_augmented_behavior_policy_h(df_va) + + # OPE - WIS + v2_IS_value, v2_WIS_value, v2_ESS_info = OPE_PDIS_h(v2_data_va, v2_pi_b_val, π_eval, gamma, epsilon=0.0) + + df_results_v2.append([v2_IS_value, v2_WIS_value, v2_ESS_info['ESS1']]) + pd.DataFrame(df_results_v2, columns=['IS_value', 'WIS_value', 'ESS1']).to_csv(out_fname, index=False) + +df_results_v2 = pd.DataFrame(df_results_v2, columns=['IS_value', 'WIS_value', 'ESS1']) +df_results_v2.to_csv(out_fname, index=False) diff --git a/sepsisSim/experiments/exp-2-annotZero.py b/sepsisSim/experiments/exp-2-annotZero.py new file mode 100644 index 0000000..e53cad5 --- /dev/null +++ b/sepsisSim/experiments/exp-2-annotZero.py @@ -0,0 +1,374 @@ +# ## Simulation parameters +exp_name = 'exp-FINAL-2' +eps = 0.10 +eps_str = '0_1' + +run_idx_length = 1_000 +N_val = 1_000 +runs = 50 + +# Number of action-flipped states +import argparse 
+parser = argparse.ArgumentParser() +parser.add_argument('--flip_num', type=int) +parser.add_argument('--flip_seed', type=int) +args = parser.parse_args() +pol_flip_num = args.flip_num +pol_flip_seed = args.flip_seed + +pol_name = f'flip{pol_flip_num}_seed{pol_flip_seed}' +out_fname = f'./results/{exp_name}/vaso_eps_{eps_str}-{pol_name}-aug_step-annotZero.csv' + +import numpy as np +import pandas as pd + +df_tmp = None +try: + df_tmp = pd.read_csv(out_fname) +except: + pass + +if df_tmp is not None: + print('File exists') + quit() + +from tqdm import tqdm +from collections import defaultdict +import pickle +import itertools +import copy +import random +import itertools +import joblib +from joblib import Parallel, delayed + +from OPE_utils_new import ( + format_data_tensor, + policy_eval_analytic_finite, + OPE_IS_h, + compute_behavior_policy_h, +) + + +def policy_eval_helper(π): + V_H = policy_eval_analytic_finite(P.transpose((1,0,2)), R, π, gamma, H) + Q_H = [(R + gamma * P.transpose((1,0,2)) @ V_H[h]) for h in range(1,H)] + [R] + J = isd @ V_H[0] + # Check recursive relationships + assert len(Q_H) == H + assert len(V_H) == H + assert np.all(Q_H[-1] == R) + assert np.all(np.sum(π * Q_H[-1], axis=1) == V_H[-1]) + assert np.all(R + gamma * P.transpose((1,0,2)) @ V_H[-1] == Q_H[-2]) + return V_H, Q_H, J + +def iqm(x): + return scipy.stats.trim_mean(x, proportiontocut=0.25, axis=None) + +NSTEPS = H = 20 # max episode length in historical data # Horizon of the MDP +G_min = -1 # the minimum possible return +G_max = 1 # the maximum possible return +nS, nA = 1442, 8 + +PROB_DIAB = 0.2 + +# Ground truth MDP model +MDP_parameters = joblib.load('../data/MDP_parameters.joblib') +P = MDP_parameters['transition_matrix_absorbing'] # (A, S, S_next) +R = MDP_parameters['reward_matrix_absorbing_SA'] # (S, A) +nS, nA = R.shape +gamma = 0.99 + +# unif rand isd, mixture of diabetic state +isd = joblib.load('../data/modified_prior_initial_state_absorbing.joblib') +isd = (isd > 
0).astype(float) +isd[:720] = isd[:720] / isd[:720].sum() * (1-PROB_DIAB) +isd[720:] = isd[720:] / isd[720:].sum() * (PROB_DIAB) + +# Precomputed optimal policy +π_star = joblib.load('../data/π_star.joblib') + + + +# ## Load data +input_dir = f'../datagen/vaso_eps_{eps_str}-100k/' + +def load_data(fname): + print('Loading data', fname, '...', end='') + df_data = pd.read_csv('{}/{}'.format(input_dir, fname)).rename(columns={'State_idx': 'State'})#[['pt_id', 'Time', 'State', 'Action', 'Reward']] + + # Assign next state + df_data['NextState'] = [*df_data['State'].iloc[1:].values, -1] + df_data.loc[df_data.groupby('pt_id')['Time'].idxmax(), 'NextState'] = -1 + df_data.loc[(df_data['Reward'] == -1), 'NextState'] = 1440 + df_data.loc[(df_data['Reward'] == 1), 'NextState'] = 1441 + + assert ((df_data['Reward'] != 0) == (df_data['Action'] == -1)).all() + + print('DONE') + return df_data + + +# df_seed1 = load_data('1-features.csv') # tr +df_seed2 = load_data('2-features.csv') # va + + +# ## C-PDIS code + +def compute_augmented_behavior_policy_h(df_data): + πh_b = np.zeros((H, nS, nA)) + hsa_counts = df_data.groupby(['Time', 'State', 'Action'])[['Weight']].sum().rename(columns={'Weight': 'count'}).reset_index() + + try: + for i, row in hsa_counts.iterrows(): + h, s, a = int(row['Time']), int(row['State']), int(row['Action']) + count = row['count'] + if row['Action'] == -1: + πh_b[h, s, :] = count + else: + πh_b[h, s, a] = count + except: + print(h,s,a) + raise + # import pdb + # pdb.set_trace() + + # assume uniform action probabilities in unobserved states + unobserved_states = (πh_b.sum(axis=-1) == 0) + πh_b[unobserved_states, :] = 1 + + # normalize action probabilities + πh_b = πh_b / πh_b.sum(axis=-1, keepdims=True) + + return πh_b + + +def format_data_tensor_cf(df_data, id_col='map_pt_id'): + """ + Converts data from a dataframe to a tensor + - df_data: pd.DataFrame with columns [id_col, Time, State, Action, Reward, NextState] + - id_col specifies the index column to 
group episodes + - data_tensor: integer tensor of shape (N, NSTEPS, 5) with the last last dimension being [t, s, a, r, s'] + """ + data_dict = dict(list(df_data.groupby(id_col))) + N = len(data_dict) + data_tensor = np.zeros((N, 2*NSTEPS, 6), dtype=float) + data_tensor[:, :, 0] = -1 # initialize all time steps to -1 + data_tensor[:, :, 2] = -1 # initialize all actions to -1 + data_tensor[:, :, 1] = -1 # initialize all states to -1 + data_tensor[:, :, 4] = -1 # initialize all next states to -1 + data_tensor[:, :, 5] = np.nan # initialize all weights to NaN + + for i, (pt_id, df_values) in tqdm(enumerate(data_dict.items()), disable=True): + values = df_values.set_index(id_col)[['Time', 'State', 'Action', 'Reward', 'NextState', 'Weight']].values + data_tensor[i, :len(values), :] = values + return data_tensor + + +def OPE_PDIS_h(data, π_b, π_e, γ, epsilon=0.01): + """ + - π_b, π_e: behavior/evaluation policy, shape (S,A) + """ + # Get a soft version of the evaluation policy for WIS + π_e_soft = π_e.astype(float) * (1 - epsilon*2) + π_unif * epsilon*2 + + # # Get a soft version of the behavior policy for WIS + # π_b_soft = π_b * (1 - epsilon) + epsilon / nA + + # Apply WIS + return _pdis_h(data, π_b, π_e_soft, γ) + +def _pdis_h(data, π_b, π_e, γ): + # For each original trajectory + v_all, rho_all = [], [] + for i, data_i in enumerate(data): + # Get all trajectories based on this trajectory + t_l = data_i[..., 0].astype(int) + s_l = data_i[..., 1].astype(int) + a_l = data_i[..., 2].astype(int) + r_l = data_i[..., 3].astype(float) + snext_l = data_i[..., 4].astype(int) + w_l = data_i[..., 5].astype(float) + + # Per-transition importance ratios + p_b = π_b[t_l, s_l, a_l] + p_e = π_e[s_l, a_l] + + # Deal with variable length sequences by setting ratio to 1 + terminated_idx = (a_l == -1) + terminating_idx = (s_l != -1) & (a_l == -1) + p_b[terminated_idx] = np.nan + p_e[terminated_idx] = np.nan + p_b[terminating_idx] = 1 + p_e[terminating_idx] = 1 + + # Per-step cumulative 
importance ratios + rho_t = (p_e / p_b) + + # # Last observed step of each trajectory + # idx_last = np.array([np.max(np.nonzero(s_l[row] != -1)) for row in range(len(s_l))]) + + # Initialize value to 0, importance ratio to 1 + v = 0 + rho_cum = 1 + + # Iterate backwards from step H to 1 + for h in reversed(range(H)): + # only start computing from the last observed step + if not (t_l == h).any(): + continue + + # do we have counterfactual annotation for this step? + if (t_l == h).sum() > 1: + # if we have counterfactual annotations for this step + j_all = np.argwhere(t_l == h).ravel() + assert np.isclose(w_l[j_all].sum(), 1) # weights add up to 1 + + # Identify factual transition and counterfactual annotations + f_, cf_ = [], [] + for j in j_all: + if snext_l[j] == 1442: # counterfactual annotation have dummy next state + cf_.append(j) + else: + f_.append(j) + assert len(f_) == 1 # there should only be one factual transition + f_ = f_[0] + v = w_l[f_]*rho_t[f_]*(r_l[f_]+γ*v) + np.sum([w_l[j]*rho_t[j]*r_l[j] for j in cf_]) + rho_cum = rho_cum * (w_l[f_]*rho_t[f_]) + np.sum([w_l[j]*rho_t[j] for j in cf_]) + else: + # we don't have counterfactual annotations for this step + # there should only be one trajectory and that must be the original traj + j = (t_l == h).argmax() + assert ~np.isnan(p_e[j]) + assert w_l[j] == 1.0 + v = rho_t[j] * (r_l[j]+γ*v) + rho_cum = rho_cum * rho_t[j] + + v_all.append(v) + rho_all.append(rho_cum) + + v_all = np.array(v_all) + rho_all = np.array(rho_all) + is_value = np.nansum(v_all) / len(rho_all) + wis_value = np.nansum(v_all) / np.nansum(rho_all) + rho_norm = rho_all / np.nansum(rho_all) + rho_nna = rho_all[~np.isnan(rho_all)] + rho_norm_nna = rho_norm[~np.isnan(rho_norm)] + ess1 = 1 / np.nansum(rho_norm_nna ** 2) + ess1_ = (np.nansum(rho_nna)) ** 2 / (np.nansum(rho_nna ** 2)) + ess2 = 1. 
/ np.nanmax(rho_norm) + return is_value, wis_value, { + 'ESS1': ess1, 'ESS2': ess2, + 'rho': rho_all, 'rho_norm': rho_norm_nna, + } + + +## Default weighting scheme for C-PDIS + +weight_a_sa = np.zeros((nS, nA, nA)) + +# default weight if no counterfactual actions +for a in range(nA): + weight_a_sa[:, a, a] = 1 + +# split equally between factual and counterfactual actions +for s in range(nS): + a = π_star.argmax(axis=1)[s] + a_tilde = a+1-2*(a%2) + weight_a_sa[s, a, a] = 0.5 + weight_a_sa[s, a, a_tilde] = 0.5 + weight_a_sa[s, a_tilde, a] = 0.5 + weight_a_sa[s, a_tilde, a_tilde] = 0.5 + +assert np.all(weight_a_sa.sum(axis=-1) == 1) + + +# ## Policies + +# vaso unif, mv abx optimal +π_unif = (np.tile(π_star.reshape((-1,2,2,2)).sum(axis=3, keepdims=True), (1,1,1,2)).reshape((-1, 8)) / 2) + + +# ### Behavior policy + +# vaso eps=0.5, mv abx optimal +π_beh = (np.tile(π_star.reshape((-1,2,2,2)).sum(axis=3, keepdims=True), (1,1,1,2)).reshape((-1, 8)) / 2) +π_beh[π_star == 1] = 1-eps +π_beh[π_beh == 0.5] = eps + +V_H_beh, Q_H_beh, J_beh = policy_eval_helper(π_beh) +J_beh + + +# ### Optimal policy +V_H_star, Q_H_star, J_star = policy_eval_helper(π_star) +J_star + + +# ### flip action for x% states + +rng_flip = np.random.default_rng(pol_flip_seed) +flip_states = rng_flip.choice(range(1440), pol_flip_num, replace=False) + +π_tmp = (np.tile(π_star.reshape((-1,2,2,2)).sum(axis=3, keepdims=True), (1,1,1,2)).reshape((-1, 8)) / 2) +π_flip = π_tmp.copy() +π_flip[π_tmp == 0.5] = 0 +π_flip[π_star == 1] = 1 +for s in flip_states: + π_flip[s, π_tmp[s] == 0.5] = 1 + π_flip[s, π_star[s] == 1] = 0 +assert π_flip.sum(axis=1).mean() == 1 + +# np.savetxt(f'./results/{exp_name}/policy_{pol_name}.txt', π_flip) + + +# ## Compare OPE + +π_eval = π_flip + + +# ### Proposed: replace future with the value function for the evaluation policy + +df_va_all2 = pd.read_pickle(f'results/vaso_eps_{eps_str}-annotOpt_df_seed2_aug_step.pkl') +V_H_eval, Q_H_eval, J_eval = policy_eval_helper(π_eval) + + 
+df_results_v2 = [] +for run in range(runs): + df_va = df_va_all2.set_index('pt_id').loc[200000+run*run_idx_length:200000+run*run_idx_length + N_val - 1+0.999].reset_index() + + df_va['map_pt_id'] = df_va['pt_id'].apply(np.floor).astype(int) + df_va = df_va.drop_duplicates(['map_pt_id', 'Time', 'State', 'Action']) \ + .sort_values(by=['map_pt_id', 'Time', 'pt_id']).reset_index(drop=True) + df_va['Weight'] = np.nan + + for i, row in tqdm(list(df_va.iterrows()), disable=True): + map_pt_id = row['map_pt_id'] + h = int(row['Time']) + s = int(row['State']) + if row['NextState'] in [1440, 1441]: + df_va.loc[i, 'Weight'] = 1.0 + elif row['NextState'] in [1442]: + a_cf = int(row['Action']) + a_f = int(a_cf+1-2*(a_cf%2)) + df_va.loc[i, 'Weight'] = weight_a_sa[s, a_f, a_cf] + df_va.loc[(df_va['map_pt_id'] == map_pt_id) + & (df_va['Time'] == h) + & (df_va['Action'] == a_f), 'Weight'] = weight_a_sa[s, a_f, a_f] + df_va.loc[i, 'Reward'] = 0 + else: + pass + df_va['Weight'] = df_va['Weight'].fillna(1) + + # OPE - WIS/WDR prep + v2_data_va = format_data_tensor_cf(df_va) + v2_pi_b_val = compute_augmented_behavior_policy_h(df_va) + + # OPE - WIS + v2_IS_value, v2_WIS_value, v2_ESS_info = OPE_PDIS_h(v2_data_va, v2_pi_b_val, π_eval, gamma, epsilon=0.0) + + df_results_v2.append([v2_IS_value, v2_WIS_value, v2_ESS_info['ESS1']]) + pd.DataFrame(df_results_v2, columns=['IS_value', 'WIS_value', 'ESS1']).to_csv(out_fname, index=False) + +df_results_v2 = pd.DataFrame(df_results_v2, columns=['IS_value', 'WIS_value', 'ESS1']) +df_results_v2.to_csv(out_fname, index=False) diff --git a/sepsisSim/experiments/exp-2-onpolicy-annotBeh.py b/sepsisSim/experiments/exp-2-onpolicy-annotBeh.py new file mode 100644 index 0000000..89082a0 --- /dev/null +++ b/sepsisSim/experiments/exp-2-onpolicy-annotBeh.py @@ -0,0 +1,341 @@ +# ## Simulation parameters +exp_name = 'exp-FINAL-2' +eps = 0.10 +eps_str = '0_1' + +run_idx_length = 1_000 +N_val = 1_000 +runs = 50 + +pol_name = 'onpolicy' +out_fname = 
f'./results/{exp_name}/vaso_eps_{eps_str}-{pol_name}-aug_step-annotBeh.csv' + +import numpy as np +import pandas as pd + +df_tmp = None +try: + df_tmp = pd.read_csv(out_fname) +except: + pass + +if df_tmp is not None: + print('File exists') + quit() + +from tqdm import tqdm +from collections import defaultdict +import pickle +import itertools +import copy +import random +import itertools +import joblib +from joblib import Parallel, delayed + +from OPE_utils_new import ( + format_data_tensor, + policy_eval_analytic_finite, + OPE_IS_h, + compute_behavior_policy_h, +) + + +def policy_eval_helper(π): + V_H = policy_eval_analytic_finite(P.transpose((1,0,2)), R, π, gamma, H) + Q_H = [(R + gamma * P.transpose((1,0,2)) @ V_H[h]) for h in range(1,H)] + [R] + J = isd @ V_H[0] + # Check recursive relationships + assert len(Q_H) == H + assert len(V_H) == H + assert np.all(Q_H[-1] == R) + assert np.all(np.sum(π * Q_H[-1], axis=1) == V_H[-1]) + assert np.all(R + gamma * P.transpose((1,0,2)) @ V_H[-1] == Q_H[-2]) + return V_H, Q_H, J + +def iqm(x): + return scipy.stats.trim_mean(x, proportiontocut=0.25, axis=None) + +NSTEPS = H = 20 # max episode length in historical data # Horizon of the MDP +G_min = -1 # the minimum possible return +G_max = 1 # the maximum possible return +nS, nA = 1442, 8 + +PROB_DIAB = 0.2 + +# Ground truth MDP model +MDP_parameters = joblib.load('../data/MDP_parameters.joblib') +P = MDP_parameters['transition_matrix_absorbing'] # (A, S, S_next) +R = MDP_parameters['reward_matrix_absorbing_SA'] # (S, A) +nS, nA = R.shape +gamma = 0.99 + +# unif rand isd, mixture of diabetic state +isd = joblib.load('../data/modified_prior_initial_state_absorbing.joblib') +isd = (isd > 0).astype(float) +isd[:720] = isd[:720] / isd[:720].sum() * (1-PROB_DIAB) +isd[720:] = isd[720:] / isd[720:].sum() * (PROB_DIAB) + +# Precomputed optimal policy +π_star = joblib.load('../data/π_star.joblib') + + + +# ## Load data +input_dir = f'../datagen/vaso_eps_{eps_str}-100k/' + +def 
load_data(fname): + print('Loading data', fname, '...', end='') + df_data = pd.read_csv('{}/{}'.format(input_dir, fname)).rename(columns={'State_idx': 'State'})#[['pt_id', 'Time', 'State', 'Action', 'Reward']] + + # Assign next state + df_data['NextState'] = [*df_data['State'].iloc[1:].values, -1] + df_data.loc[df_data.groupby('pt_id')['Time'].idxmax(), 'NextState'] = -1 + df_data.loc[(df_data['Reward'] == -1), 'NextState'] = 1440 + df_data.loc[(df_data['Reward'] == 1), 'NextState'] = 1441 + + assert ((df_data['Reward'] != 0) == (df_data['Action'] == -1)).all() + + print('DONE') + return df_data + + +# df_seed1 = load_data('1-features.csv') # tr +df_seed2 = load_data('2-features.csv') # va + + +# ## C-PDIS code + +def compute_augmented_behavior_policy_h(df_data): + πh_b = np.zeros((H, nS, nA)) + hsa_counts = df_data.groupby(['Time', 'State', 'Action'])[['Weight']].sum().rename(columns={'Weight': 'count'}).reset_index() + + try: + for i, row in hsa_counts.iterrows(): + h, s, a = int(row['Time']), int(row['State']), int(row['Action']) + count = row['count'] + if row['Action'] == -1: + πh_b[h, s, :] = count + else: + πh_b[h, s, a] = count + except: + print(h,s,a) + raise + # import pdb + # pdb.set_trace() + + # assume uniform action probabilities in unobserved states + unobserved_states = (πh_b.sum(axis=-1) == 0) + πh_b[unobserved_states, :] = 1 + + # normalize action probabilities + πh_b = πh_b / πh_b.sum(axis=-1, keepdims=True) + + return πh_b + + +def format_data_tensor_cf(df_data, id_col='map_pt_id'): + """ + Converts data from a dataframe to a tensor + - df_data: pd.DataFrame with columns [id_col, Time, State, Action, Reward, NextState] + - id_col specifies the index column to group episodes + - data_tensor: integer tensor of shape (N, NSTEPS, 5) with the last last dimension being [t, s, a, r, s'] + """ + data_dict = dict(list(df_data.groupby(id_col))) + N = len(data_dict) + data_tensor = np.zeros((N, 2*NSTEPS, 6), dtype=float) + data_tensor[:, :, 0] = -1 # 
initialize all time steps to -1 + data_tensor[:, :, 2] = -1 # initialize all actions to -1 + data_tensor[:, :, 1] = -1 # initialize all states to -1 + data_tensor[:, :, 4] = -1 # initialize all next states to -1 + data_tensor[:, :, 5] = np.nan # initialize all weights to NaN + + for i, (pt_id, df_values) in tqdm(enumerate(data_dict.items()), disable=True): + values = df_values.set_index(id_col)[['Time', 'State', 'Action', 'Reward', 'NextState', 'Weight']].values + data_tensor[i, :len(values), :] = values + return data_tensor + + +def OPE_PDIS_h(data, π_b, π_e, γ, epsilon=0.01): + """ + - π_b, π_e: behavior/evaluation policy, shape (S,A) + """ + # Get a soft version of the evaluation policy for WIS + π_e_soft = π_e.astype(float) * (1 - epsilon*2) + π_unif * epsilon*2 + + # # Get a soft version of the behavior policy for WIS + # π_b_soft = π_b * (1 - epsilon) + epsilon / nA + + # Apply WIS + return _pdis_h(data, π_b, π_e_soft, γ) + +def _pdis_h(data, π_b, π_e, γ): + # For each original trajectory + v_all, rho_all = [], [] + for i, data_i in enumerate(data): + # Get all trajectories based on this trajectory + t_l = data_i[..., 0].astype(int) + s_l = data_i[..., 1].astype(int) + a_l = data_i[..., 2].astype(int) + r_l = data_i[..., 3].astype(float) + snext_l = data_i[..., 4].astype(int) + w_l = data_i[..., 5].astype(float) + + # Per-transition importance ratios + p_b = π_b[t_l, s_l, a_l] + p_e = π_e[s_l, a_l] + + # Deal with variable length sequences by setting ratio to 1 + terminated_idx = (a_l == -1) + terminating_idx = (s_l != -1) & (a_l == -1) + p_b[terminated_idx] = np.nan + p_e[terminated_idx] = np.nan + p_b[terminating_idx] = 1 + p_e[terminating_idx] = 1 + + # Per-step cumulative importance ratios + rho_t = (p_e / p_b) + + # # Last observed step of each trajectory + # idx_last = np.array([np.max(np.nonzero(s_l[row] != -1)) for row in range(len(s_l))]) + + # Initialize value to 0, importance ratio to 1 + v = 0 + rho_cum = 1 + + # Iterate backwards from step H to 1 
+ for h in reversed(range(H)): + # only start computing from the last observed step + if not (t_l == h).any(): + continue + + # do we have counterfactual annotation for this step? + if (t_l == h).sum() > 1: + # if we have counterfactual annotations for this step + j_all = np.argwhere(t_l == h).ravel() + assert np.isclose(w_l[j_all].sum(), 1) # weights add up to 1 + + # Identify factual transition and counterfactual annotations + f_, cf_ = [], [] + for j in j_all: + if snext_l[j] == 1442: # counterfactual annotation have dummy next state + cf_.append(j) + else: + f_.append(j) + assert len(f_) == 1 # there should only be one factual transition + f_ = f_[0] + v = w_l[f_]*rho_t[f_]*(r_l[f_]+γ*v) + np.sum([w_l[j]*rho_t[j]*r_l[j] for j in cf_]) + rho_cum = rho_cum * (w_l[f_]*rho_t[f_]) + np.sum([w_l[j]*rho_t[j] for j in cf_]) + else: + # we don't have counterfactual annotations for this step + # there should only be one trajectory and that must be the original traj + j = (t_l == h).argmax() + assert ~np.isnan(p_e[j]) + assert w_l[j] == 1.0 + v = rho_t[j] * (r_l[j]+γ*v) + rho_cum = rho_cum * rho_t[j] + + v_all.append(v) + rho_all.append(rho_cum) + + v_all = np.array(v_all) + rho_all = np.array(rho_all) + is_value = np.nansum(v_all) / len(rho_all) + wis_value = np.nansum(v_all) / np.nansum(rho_all) + rho_norm = rho_all / np.nansum(rho_all) + rho_nna = rho_all[~np.isnan(rho_all)] + rho_norm_nna = rho_norm[~np.isnan(rho_norm)] + ess1 = 1 / np.nansum(rho_norm_nna ** 2) + ess1_ = (np.nansum(rho_nna)) ** 2 / (np.nansum(rho_nna ** 2)) + ess2 = 1. 
/ np.nanmax(rho_norm) + return is_value, wis_value, { + 'ESS1': ess1, 'ESS2': ess2, + 'rho': rho_all, 'rho_norm': rho_norm_nna, + } + + +## Default weighting scheme for C-PDIS + +weight_a_sa = np.zeros((nS, nA, nA)) + +# default weight if no counterfactual actions +for a in range(nA): + weight_a_sa[:, a, a] = 1 + +# split equally between factual and counterfactual actions +for s in range(nS): + a = π_star.argmax(axis=1)[s] + a_tilde = a+1-2*(a%2) + weight_a_sa[s, a, a] = 0.5 + weight_a_sa[s, a, a_tilde] = 0.5 + weight_a_sa[s, a_tilde, a] = 0.5 + weight_a_sa[s, a_tilde, a_tilde] = 0.5 + +assert np.all(weight_a_sa.sum(axis=-1) == 1) + + +# ## Policies + +# vaso unif, mv abx optimal +π_unif = (np.tile(π_star.reshape((-1,2,2,2)).sum(axis=3, keepdims=True), (1,1,1,2)).reshape((-1, 8)) / 2) + + +# ### Behavior policy + +# vaso eps=0.5, mv abx optimal +π_beh = (np.tile(π_star.reshape((-1,2,2,2)).sum(axis=3, keepdims=True), (1,1,1,2)).reshape((-1, 8)) / 2) +π_beh[π_star == 1] = 1-eps +π_beh[π_beh == 0.5] = eps + +V_H_beh, Q_H_beh, J_beh = policy_eval_helper(π_beh) + +# ## Compare OPE + +π_eval = π_beh + + +# ### Proposed: replace future with the value function for the evaluation policy + +df_va_all2 = pd.read_pickle(f'results/vaso_eps_{eps_str}-annotOpt_df_seed2_aug_step.pkl') +V_H_eval, Q_H_eval, J_eval = policy_eval_helper(π_eval) + + +df_results_v2 = [] +for run in range(runs): + df_va = df_va_all2.set_index('pt_id').loc[200000+run*run_idx_length:200000+run*run_idx_length + N_val - 1+0.999].reset_index() + + df_va['map_pt_id'] = df_va['pt_id'].apply(np.floor).astype(int) + df_va = df_va.drop_duplicates(['map_pt_id', 'Time', 'State', 'Action']) \ + .sort_values(by=['map_pt_id', 'Time', 'pt_id']).reset_index(drop=True) + df_va['Weight'] = np.nan + + for i, row in tqdm(list(df_va.iterrows()), disable=True): + map_pt_id = row['map_pt_id'] + h = int(row['Time']) + s = int(row['State']) + if row['NextState'] in [1440, 1441]: + df_va.loc[i, 'Weight'] = 1.0 + elif 
row['NextState'] in [1442]: + a_cf = int(row['Action']) + a_f = int(a_cf+1-2*(a_cf%2)) + df_va.loc[i, 'Weight'] = weight_a_sa[s, a_f, a_cf] + df_va.loc[(df_va['map_pt_id'] == map_pt_id) + & (df_va['Time'] == h) + & (df_va['Action'] == a_f), 'Weight'] = weight_a_sa[s, a_f, a_f] + df_va.loc[i, 'Reward'] = Q_H_beh[h][s, a_cf] + else: + pass + df_va['Weight'] = df_va['Weight'].fillna(1) + + # OPE - WIS/WDR prep + v2_data_va = format_data_tensor_cf(df_va) + v2_pi_b_val = compute_augmented_behavior_policy_h(df_va) + + # OPE - WIS + v2_IS_value, v2_WIS_value, v2_ESS_info = OPE_PDIS_h(v2_data_va, v2_pi_b_val, π_eval, gamma, epsilon=0.0) + + df_results_v2.append([v2_IS_value, v2_WIS_value, v2_ESS_info['ESS1']]) + pd.DataFrame(df_results_v2, columns=['IS_value', 'WIS_value', 'ESS1']).to_csv(out_fname, index=False) + +df_results_v2 = pd.DataFrame(df_results_v2, columns=['IS_value', 'WIS_value', 'ESS1']) +df_results_v2.to_csv(out_fname, index=False) diff --git a/sepsisSim/experiments/exp-3-annotEvalNoise.py b/sepsisSim/experiments/exp-3-annotEvalNoise.py new file mode 100644 index 0000000..359742b --- /dev/null +++ b/sepsisSim/experiments/exp-3-annotEvalNoise.py @@ -0,0 +1,378 @@ +# ## Simulation parameters +exp_name = 'exp-FINAL-3' +eps = 0.10 +eps_str = '0_1' + +run_idx_length = 1_000 +N_val = 1_000 +runs = 50 + +# Number of action-flipped states +import argparse +parser = argparse.ArgumentParser() +parser.add_argument('--flip_num', type=int) +parser.add_argument('--flip_seed', type=int) +parser.add_argument('--annot_noise', type=float) +args = parser.parse_args() +pol_flip_num = args.flip_num +pol_flip_seed = args.flip_seed +annot_noise = args.annot_noise + +pol_name = f'flip{pol_flip_num}_seed{pol_flip_seed}' +out_fname = f'./results/{exp_name}/vaso_eps_{eps_str}-{pol_name}-aug_step-annotEval-Noise_{annot_noise}.csv' + +import numpy as np +import pandas as pd + +df_tmp = None +try: + df_tmp = pd.read_csv(out_fname) +except: + pass + +if df_tmp is not None: + print('File 
exists') + quit() + +from tqdm import tqdm +from collections import defaultdict +import pickle +import itertools +import copy +import random +import itertools +import joblib +from joblib import Parallel, delayed + +from OPE_utils_new import ( + format_data_tensor, + policy_eval_analytic_finite, + OPE_IS_h, + compute_behavior_policy_h, +) + + +def policy_eval_helper(π): + V_H = policy_eval_analytic_finite(P.transpose((1,0,2)), R, π, gamma, H) + Q_H = [(R + gamma * P.transpose((1,0,2)) @ V_H[h]) for h in range(1,H)] + [R] + J = isd @ V_H[0] + # Check recursive relationships + assert len(Q_H) == H + assert len(V_H) == H + assert np.all(Q_H[-1] == R) + assert np.all(np.sum(π * Q_H[-1], axis=1) == V_H[-1]) + assert np.all(R + gamma * P.transpose((1,0,2)) @ V_H[-1] == Q_H[-2]) + return V_H, Q_H, J + +def iqm(x): + return scipy.stats.trim_mean(x, proportiontocut=0.25, axis=None) + +NSTEPS = H = 20 # max episode length in historical data # Horizon of the MDP +G_min = -1 # the minimum possible return +G_max = 1 # the maximum possible return +nS, nA = 1442, 8 + +PROB_DIAB = 0.2 + +# Ground truth MDP model +MDP_parameters = joblib.load('../data/MDP_parameters.joblib') +P = MDP_parameters['transition_matrix_absorbing'] # (A, S, S_next) +R = MDP_parameters['reward_matrix_absorbing_SA'] # (S, A) +nS, nA = R.shape +gamma = 0.99 + +# unif rand isd, mixture of diabetic state +isd = joblib.load('../data/modified_prior_initial_state_absorbing.joblib') +isd = (isd > 0).astype(float) +isd[:720] = isd[:720] / isd[:720].sum() * (1-PROB_DIAB) +isd[720:] = isd[720:] / isd[720:].sum() * (PROB_DIAB) + +# Precomputed optimal policy +π_star = joblib.load('../data/π_star.joblib') + + + +# ## Load data +input_dir = f'../datagen/vaso_eps_{eps_str}-100k/' + +def load_data(fname): + print('Loading data', fname, '...', end='') + df_data = pd.read_csv('{}/{}'.format(input_dir, fname)).rename(columns={'State_idx': 'State'})#[['pt_id', 'Time', 'State', 'Action', 'Reward']] + + # Assign next state + 
df_data['NextState'] = [*df_data['State'].iloc[1:].values, -1] + df_data.loc[df_data.groupby('pt_id')['Time'].idxmax(), 'NextState'] = -1 + df_data.loc[(df_data['Reward'] == -1), 'NextState'] = 1440 + df_data.loc[(df_data['Reward'] == 1), 'NextState'] = 1441 + + assert ((df_data['Reward'] != 0) == (df_data['Action'] == -1)).all() + + print('DONE') + return df_data + + +# df_seed1 = load_data('1-features.csv') # tr +df_seed2 = load_data('2-features.csv') # va + + +# ## C-PDIS code + +def compute_augmented_behavior_policy_h(df_data): + πh_b = np.zeros((H, nS, nA)) + hsa_counts = df_data.groupby(['Time', 'State', 'Action'])[['Weight']].sum().rename(columns={'Weight': 'count'}).reset_index() + + try: + for i, row in hsa_counts.iterrows(): + h, s, a = int(row['Time']), int(row['State']), int(row['Action']) + count = row['count'] + if row['Action'] == -1: + πh_b[h, s, :] = count + else: + πh_b[h, s, a] = count + except: + print(h,s,a) + raise + # import pdb + # pdb.set_trace() + + # assume uniform action probabilities in unobserved states + unobserved_states = (πh_b.sum(axis=-1) == 0) + πh_b[unobserved_states, :] = 1 + + # normalize action probabilities + πh_b = πh_b / πh_b.sum(axis=-1, keepdims=True) + + return πh_b + + +def format_data_tensor_cf(df_data, id_col='map_pt_id'): + """ + Converts data from a dataframe to a tensor + - df_data: pd.DataFrame with columns [id_col, Time, State, Action, Reward, NextState] + - id_col specifies the index column to group episodes + - data_tensor: integer tensor of shape (N, NSTEPS, 5) with the last last dimension being [t, s, a, r, s'] + """ + data_dict = dict(list(df_data.groupby(id_col))) + N = len(data_dict) + data_tensor = np.zeros((N, 2*NSTEPS, 6), dtype=float) + data_tensor[:, :, 0] = -1 # initialize all time steps to -1 + data_tensor[:, :, 2] = -1 # initialize all actions to -1 + data_tensor[:, :, 1] = -1 # initialize all states to -1 + data_tensor[:, :, 4] = -1 # initialize all next states to -1 + data_tensor[:, :, 5] = 
np.nan # initialize all weights to NaN + + for i, (pt_id, df_values) in tqdm(enumerate(data_dict.items()), disable=True): + values = df_values.set_index(id_col)[['Time', 'State', 'Action', 'Reward', 'NextState', 'Weight']].values + data_tensor[i, :len(values), :] = values + return data_tensor + + +def OPE_PDIS_h(data, π_b, π_e, γ, epsilon=0.01): + """ + - π_b, π_e: behavior/evaluation policy, shape (S,A) + """ + # Get a soft version of the evaluation policy for WIS + π_e_soft = π_e.astype(float) * (1 - epsilon*2) + π_unif * epsilon*2 + + # # Get a soft version of the behavior policy for WIS + # π_b_soft = π_b * (1 - epsilon) + epsilon / nA + + # Apply WIS + return _pdis_h(data, π_b, π_e_soft, γ) + +def _pdis_h(data, π_b, π_e, γ): + # For each original trajectory + v_all, rho_all = [], [] + for i, data_i in enumerate(data): + # Get all trajectories based on this trajectory + t_l = data_i[..., 0].astype(int) + s_l = data_i[..., 1].astype(int) + a_l = data_i[..., 2].astype(int) + r_l = data_i[..., 3].astype(float) + snext_l = data_i[..., 4].astype(int) + w_l = data_i[..., 5].astype(float) + + # Per-transition importance ratios + p_b = π_b[t_l, s_l, a_l] + p_e = π_e[s_l, a_l] + + # Deal with variable length sequences by setting ratio to 1 + terminated_idx = (a_l == -1) + terminating_idx = (s_l != -1) & (a_l == -1) + p_b[terminated_idx] = np.nan + p_e[terminated_idx] = np.nan + p_b[terminating_idx] = 1 + p_e[terminating_idx] = 1 + + # Per-step cumulative importance ratios + rho_t = (p_e / p_b) + + # # Last observed step of each trajectory + # idx_last = np.array([np.max(np.nonzero(s_l[row] != -1)) for row in range(len(s_l))]) + + # Initialize value to 0, importance ratio to 1 + v = 0 + rho_cum = 1 + + # Iterate backwards from step H to 1 + for h in reversed(range(H)): + # only start computing from the last observed step + if not (t_l == h).any(): + continue + + # do we have counterfactual annotation for this step? 
+ if (t_l == h).sum() > 1: + # if we have counterfactual annotations for this step + j_all = np.argwhere(t_l == h).ravel() + assert np.isclose(w_l[j_all].sum(), 1) # weights add up to 1 + + # Identify factual transition and counterfactual annotations + f_, cf_ = [], [] + for j in j_all: + if snext_l[j] == 1442: # counterfactual annotation have dummy next state + cf_.append(j) + else: + f_.append(j) + assert len(f_) == 1 # there should only be one factual transition + f_ = f_[0] + v = w_l[f_]*rho_t[f_]*(r_l[f_]+γ*v) + np.sum([w_l[j]*rho_t[j]*r_l[j] for j in cf_]) + rho_cum = rho_cum * (w_l[f_]*rho_t[f_]) + np.sum([w_l[j]*rho_t[j] for j in cf_]) + else: + # we don't have counterfactual annotations for this step + # there should only be one trajectory and that must be the original traj + j = (t_l == h).argmax() + assert ~np.isnan(p_e[j]) + assert w_l[j] == 1.0 + v = rho_t[j] * (r_l[j]+γ*v) + rho_cum = rho_cum * rho_t[j] + + v_all.append(v) + rho_all.append(rho_cum) + + v_all = np.array(v_all) + rho_all = np.array(rho_all) + is_value = np.nansum(v_all) / len(rho_all) + wis_value = np.nansum(v_all) / np.nansum(rho_all) + rho_norm = rho_all / np.nansum(rho_all) + rho_nna = rho_all[~np.isnan(rho_all)] + rho_norm_nna = rho_norm[~np.isnan(rho_norm)] + ess1 = 1 / np.nansum(rho_norm_nna ** 2) + ess1_ = (np.nansum(rho_nna)) ** 2 / (np.nansum(rho_nna ** 2)) + ess2 = 1. 
/ np.nanmax(rho_norm) + return is_value, wis_value, { + 'ESS1': ess1, 'ESS2': ess2, + 'rho': rho_all, 'rho_norm': rho_norm_nna, + } + + +## Default weighting scheme for C-PDIS + +weight_a_sa = np.zeros((nS, nA, nA)) + +# default weight if no counterfactual actions +for a in range(nA): + weight_a_sa[:, a, a] = 1 + +# split equally between factual and counterfactual actions +for s in range(nS): + a = π_star.argmax(axis=1)[s] + a_tilde = a+1-2*(a%2) + weight_a_sa[s, a, a] = 0.5 + weight_a_sa[s, a, a_tilde] = 0.5 + weight_a_sa[s, a_tilde, a] = 0.5 + weight_a_sa[s, a_tilde, a_tilde] = 0.5 + +assert np.all(weight_a_sa.sum(axis=-1) == 1) + + +# ## Policies + +# vaso unif, mv abx optimal +π_unif = (np.tile(π_star.reshape((-1,2,2,2)).sum(axis=3, keepdims=True), (1,1,1,2)).reshape((-1, 8)) / 2) + + +# ### Behavior policy + +# vaso eps=0.5, mv abx optimal +π_beh = (np.tile(π_star.reshape((-1,2,2,2)).sum(axis=3, keepdims=True), (1,1,1,2)).reshape((-1, 8)) / 2) +π_beh[π_star == 1] = 1-eps +π_beh[π_beh == 0.5] = eps + +V_H_beh, Q_H_beh, J_beh = policy_eval_helper(π_beh) +J_beh + + +# ### Optimal policy +V_H_star, Q_H_star, J_star = policy_eval_helper(π_star) +J_star + + +# ### flip action for x% states + +rng_flip = np.random.default_rng(pol_flip_seed) +flip_states = rng_flip.choice(range(1440), pol_flip_num, replace=False) + +π_tmp = (np.tile(π_star.reshape((-1,2,2,2)).sum(axis=3, keepdims=True), (1,1,1,2)).reshape((-1, 8)) / 2) +π_flip = π_tmp.copy() +π_flip[π_tmp == 0.5] = 0 +π_flip[π_star == 1] = 1 +for s in flip_states: + π_flip[s, π_tmp[s] == 0.5] = 1 + π_flip[s, π_star[s] == 1] = 0 +assert π_flip.sum(axis=1).mean() == 1 + +# np.savetxt(f'./results/{exp_name}/policy_{pol_name}.txt', π_flip) + + +# ## Compare OPE + +π_eval = π_flip + + +# ### Proposed: replace future with the value function for the evaluation policy +rng_annot = np.random.default_rng(seed=123456789) + + +df_va_all2 = pd.read_pickle(f'results/vaso_eps_{eps_str}-annotOpt_df_seed2_aug_step.pkl') +V_H_eval, 
Q_H_eval, J_eval = policy_eval_helper(π_eval) + + +df_results_v2 = [] +for run in range(runs): + df_va = df_va_all2.set_index('pt_id').loc[200000+run*run_idx_length:200000+run*run_idx_length + N_val - 1+0.999].reset_index() + + df_va['map_pt_id'] = df_va['pt_id'].apply(np.floor).astype(int) + df_va = df_va.drop_duplicates(['map_pt_id', 'Time', 'State', 'Action']) \ + .sort_values(by=['map_pt_id', 'Time', 'pt_id']).reset_index(drop=True) + df_va['Weight'] = np.nan + + for i, row in tqdm(list(df_va.iterrows()), disable=True): + map_pt_id = row['map_pt_id'] + h = int(row['Time']) + s = int(row['State']) + if row['NextState'] in [1440, 1441]: + df_va.loc[i, 'Weight'] = 1.0 + elif row['NextState'] in [1442]: + a_cf = int(row['Action']) + a_f = int(a_cf+1-2*(a_cf%2)) + df_va.loc[i, 'Weight'] = weight_a_sa[s, a_f, a_cf] + df_va.loc[(df_va['map_pt_id'] == map_pt_id) + & (df_va['Time'] == h) + & (df_va['Action'] == a_f), 'Weight'] = weight_a_sa[s, a_f, a_f] + df_va.loc[i, 'Reward'] = Q_H_eval[h][s, a_cf] + rng_annot.normal(0, annot_noise) + else: + pass + df_va['Weight'] = df_va['Weight'].fillna(1) + + # OPE - WIS/WDR prep + v2_data_va = format_data_tensor_cf(df_va) + v2_pi_b_val = compute_augmented_behavior_policy_h(df_va) + + # OPE - WIS + v2_IS_value, v2_WIS_value, v2_ESS_info = OPE_PDIS_h(v2_data_va, v2_pi_b_val, π_eval, gamma, epsilon=0.0) + + df_results_v2.append([v2_IS_value, v2_WIS_value, v2_ESS_info['ESS1']]) + pd.DataFrame(df_results_v2, columns=['IS_value', 'WIS_value', 'ESS1']).to_csv(out_fname, index=False) + +df_results_v2 = pd.DataFrame(df_results_v2, columns=['IS_value', 'WIS_value', 'ESS1']) +df_results_v2.to_csv(out_fname, index=False) diff --git a/sepsisSim/experiments/exp-4-annotEvalMissing.py b/sepsisSim/experiments/exp-4-annotEvalMissing.py new file mode 100644 index 0000000..d1d98c3 --- /dev/null +++ b/sepsisSim/experiments/exp-4-annotEvalMissing.py @@ -0,0 +1,398 @@ +# ## Simulation parameters +exp_name = 'exp-FINAL-4' +eps = 0.10 +eps_str = '0_1' + 
run_idx_length = 1_000
N_val = 1_000
runs = 50

# Number of action-flipped states
import argparse
parser = argparse.ArgumentParser()
parser.add_argument('--flip_num', type=int)
parser.add_argument('--flip_seed', type=int)
parser.add_argument('--annot_noise', type=float, default=0.2)
parser.add_argument('--annot_ratio', type=float, default=1.0)
args = parser.parse_args()
pol_flip_num = args.flip_num
pol_flip_seed = args.flip_seed
annot_noise = args.annot_noise
annot_ratio = args.annot_ratio

pol_name = f'flip{pol_flip_num}_seed{pol_flip_seed}'
out_fname = f'./results/{exp_name}/vaso_eps_{eps_str}-{pol_name}-aug_step-annotEval-Noise_{annot_noise}-Missing_{annot_ratio}.csv'

import numpy as np
import pandas as pd

# Skip this experiment entirely if the output file already exists
# (makes re-launching a batch of jobs idempotent).
df_tmp = None
try:
    df_tmp = pd.read_csv(out_fname)
except Exception:  # fixed: bare except; best-effort probe for an existing file
    pass

if df_tmp is not None:
    print('File exists')
    quit()

from tqdm import tqdm
from collections import defaultdict
import pickle
import itertools
import copy
import random
import joblib
from joblib import Parallel, delayed

from OPE_utils_new import (
    format_data_tensor,
    policy_eval_analytic_finite,
    OPE_IS_h,
    compute_behavior_policy_h,
)


def policy_eval_helper(π):
    """Finite-horizon analytic evaluation of policy π (S, A).

    Returns (V_H, Q_H, J): per-step state values V_H[h] of shape (S,),
    per-step action values Q_H[h] of shape (S, A), and the initial-state
    value J = isd @ V_H[0]. Uses module-level P, R, gamma, H, isd.
    """
    V_H = policy_eval_analytic_finite(P.transpose((1,0,2)), R, π, gamma, H)
    Q_H = [(R + gamma * P.transpose((1,0,2)) @ V_H[h]) for h in range(1,H)] + [R]
    J = isd @ V_H[0]
    # Check recursive relationships
    assert len(Q_H) == H
    assert len(V_H) == H
    assert np.all(Q_H[-1] == R)
    assert np.all(np.sum(π * Q_H[-1], axis=1) == V_H[-1])
    assert np.all(R + gamma * P.transpose((1,0,2)) @ V_H[-1] == Q_H[-2])
    return V_H, Q_H, J

def iqm(x):
    """Interquartile mean of x (trims 25% from each tail)."""
    import scipy.stats  # fixed: `scipy` was referenced but never imported
    return scipy.stats.trim_mean(x, proportiontocut=0.25, axis=None)

NSTEPS = H = 20 # max episode length in historical data # Horizon of the MDP
G_min = -1 # the minimum possible return
G_max = 1 # the maximum possible return
nS, nA = 1442, 8

PROB_DIAB = 0.2

# Ground truth MDP model
MDP_parameters = joblib.load('../data/MDP_parameters.joblib')
P = MDP_parameters['transition_matrix_absorbing'] # (A, S, S_next)
R = MDP_parameters['reward_matrix_absorbing_SA'] # (S, A)
nS, nA = R.shape
gamma = 0.99

# unif rand isd, mixture of diabetic state
isd = joblib.load('../data/modified_prior_initial_state_absorbing.joblib')
isd = (isd > 0).astype(float)
isd[:720] = isd[:720] / isd[:720].sum() * (1-PROB_DIAB)
isd[720:] = isd[720:] / isd[720:].sum() * (PROB_DIAB)

# Precomputed optimal policy
π_star = joblib.load('../data/π_star.joblib')



# ## Load data
input_dir = f'../datagen/vaso_eps_{eps_str}-100k/'

def load_data(fname):
    """Load an episode CSV and attach a NextState column.

    Terminal transitions get dummy next-state codes: 1440 (reward -1) and
    1441 (reward +1); the last step of each trajectory otherwise gets -1.
    """
    print('Loading data', fname, '...', end='')
    df_data = pd.read_csv('{}/{}'.format(input_dir, fname)).rename(columns={'State_idx': 'State'})#[['pt_id', 'Time', 'State', 'Action', 'Reward']]

    # Assign next state by shifting State up one row, then patching boundaries
    df_data['NextState'] = [*df_data['State'].iloc[1:].values, -1]
    df_data.loc[df_data.groupby('pt_id')['Time'].idxmax(), 'NextState'] = -1
    df_data.loc[(df_data['Reward'] == -1), 'NextState'] = 1440
    df_data.loc[(df_data['Reward'] == 1), 'NextState'] = 1441

    # nonzero reward must coincide with the terminal marker action -1
    assert ((df_data['Reward'] != 0) == (df_data['Action'] == -1)).all()

    print('DONE')
    return df_data


# df_seed1 = load_data('1-features.csv') # tr
df_seed2 = load_data('2-features.csv') # va


# ## C-PDIS code

def compute_augmented_behavior_policy_h(df_data):
    """Estimate the time-dependent behavior policy πh_b of shape (H, S, A)
    from Weight-weighted (Time, State, Action) counts in df_data."""
    πh_b = np.zeros((H, nS, nA))
    hsa_counts = df_data.groupby(['Time', 'State', 'Action'])[['Weight']].sum().rename(columns={'Weight': 'count'}).reset_index()

    try:
        for i, row in hsa_counts.iterrows():
            h, s, a = int(row['Time']), int(row['State']), int(row['Action'])
            count = row['count']
            if row['Action'] == -1:
                # terminal marker action: spread the count over all actions
                πh_b[h, s, :] = count
            else:
                πh_b[h, s, a] = count
    except Exception:  # fixed: bare except; report the offending triple, then re-raise
        print(h,s,a)
        raise

    # assume uniform action probabilities in unobserved states
    unobserved_states = (πh_b.sum(axis=-1) == 0)
    πh_b[unobserved_states, :] = 1

    # normalize action probabilities
    πh_b = πh_b / πh_b.sum(axis=-1, keepdims=True)

    return πh_b


def format_data_tensor_cf(df_data, id_col='map_pt_id'):
    """
    Converts data from a dataframe to a tensor
    - df_data: pd.DataFrame with columns [id_col, Time, State, Action, Reward, NextState, Weight]
    - id_col specifies the index column to group episodes
    - data_tensor: float tensor of shape (N, 2*NSTEPS, 6) with the last dimension
      being [t, s, a, r, s', w] (fixed: docstring previously said (N, NSTEPS, 5) int)
    """
    data_dict = dict(list(df_data.groupby(id_col)))
    N = len(data_dict)
    data_tensor = np.zeros((N, 2*NSTEPS, 6), dtype=float)
    data_tensor[:, :, 0] = -1 # initialize all time steps to -1
    data_tensor[:, :, 2] = -1 # initialize all actions to -1
    data_tensor[:, :, 1] = -1 # initialize all states to -1
    data_tensor[:, :, 4] = -1 # initialize all next states to -1
    data_tensor[:, :, 5] = np.nan # initialize all weights to NaN

    for i, (pt_id, df_values) in tqdm(enumerate(data_dict.items()), disable=True):
        values = df_values.set_index(id_col)[['Time', 'State', 'Action', 'Reward', 'NextState', 'Weight']].values
        data_tensor[i, :len(values), :] = values
    return data_tensor


def OPE_PDIS_h(data, π_b, π_e, γ, epsilon=0.01):
    """
    Per-decision importance sampling with counterfactual annotations.
    - π_b: time-dependent behavior policy, shape (H, S, A)
      (fixed: docstring previously claimed (S, A) for both policies)
    - π_e: evaluation policy, shape (S, A); softened toward π_unif by epsilon
    Returns (IS value, WIS value, diagnostics dict).
    """
    # Get a soft version of the evaluation policy for WIS
    π_e_soft = π_e.astype(float) * (1 - epsilon*2) + π_unif * epsilon*2

    # Apply WIS
    return _pdis_h(data, π_b, π_e_soft, γ)

def _pdis_h(data, π_b, π_e, γ):
    """Core backward-recursive C-PDIS estimator over the (N, 2H, 6) tensor."""
    # For each original trajectory
    v_all, rho_all = [], []
    for i, data_i in enumerate(data):
        # Unpack the [t, s, a, r, s', w] columns for this trajectory
        t_l = data_i[..., 0].astype(int)
        s_l = data_i[..., 1].astype(int)
        a_l = data_i[..., 2].astype(int)
        r_l = data_i[..., 3].astype(float)
        snext_l = data_i[..., 4].astype(int)
        w_l = data_i[..., 5].astype(float)

        # Per-transition importance ratios
        p_b = π_b[t_l, s_l, a_l]
        p_e = π_e[s_l, a_l]

        # Deal with variable length sequences by setting ratio to 1
        terminated_idx = (a_l == -1)
        terminating_idx = (s_l != -1) & (a_l == -1)
        p_b[terminated_idx] = np.nan
        p_e[terminated_idx] = np.nan
        p_b[terminating_idx] = 1
        p_e[terminating_idx] = 1

        # Per-step importance ratios
        rho_t = (p_e / p_b)

        # Initialize value to 0, importance ratio to 1
        v = 0
        rho_cum = 1

        # Iterate backwards from step H to 1
        for h in reversed(range(H)):
            # only start computing from the last observed step
            if not (t_l == h).any():
                continue

            # do we have counterfactual annotation for this step?
            if (t_l == h).sum() > 1:
                # if we have counterfactual annotations for this step
                j_all = np.argwhere(t_l == h).ravel()
                assert np.isclose(w_l[j_all].sum(), 1) # weights add up to 1

                # Identify factual transition and counterfactual annotations
                f_, cf_ = [], []
                for j in j_all:
                    if snext_l[j] == 1442: # counterfactual annotation have dummy next state
                        cf_.append(j)
                    else:
                        f_.append(j)
                assert len(f_) == 1 # there should only be one factual transition
                f_ = f_[0]
                # Factual term carries the recursive future v; annotation terms
                # contribute only their (annotated) immediate reward.
                v = w_l[f_]*rho_t[f_]*(r_l[f_]+γ*v) + np.sum([w_l[j]*rho_t[j]*r_l[j] for j in cf_])
                rho_cum = rho_cum * (w_l[f_]*rho_t[f_]) + np.sum([w_l[j]*rho_t[j] for j in cf_])
            else:
                # we don't have counterfactual annotations for this step
                # there should only be one trajectory and that must be the original traj
                j = (t_l == h).argmax()
                assert ~np.isnan(p_e[j])
                assert w_l[j] == 1.0
                v = rho_t[j] * (r_l[j]+γ*v)
                rho_cum = rho_cum * rho_t[j]

        v_all.append(v)
        rho_all.append(rho_cum)

    v_all = np.array(v_all)
    rho_all = np.array(rho_all)
    is_value = np.nansum(v_all) / len(rho_all)
    wis_value = np.nansum(v_all) / np.nansum(rho_all)
    rho_norm = rho_all / np.nansum(rho_all)
    rho_norm_nna = rho_norm[~np.isnan(rho_norm)]
    # fixed: removed dead locals (rho_nna, ess1_) that were computed and discarded
    ess1 = 1 / np.nansum(rho_norm_nna ** 2)  # Kish effective sample size
    ess2 = 1. / np.nanmax(rho_norm)
    return is_value, wis_value, {
        'ESS1': ess1, 'ESS2': ess2,
        'rho': rho_all, 'rho_norm': rho_norm_nna,
    }


## Default weighting scheme for C-PDIS

weight_a_sa = np.zeros((nS, nA, nA))

# default weight if no counterfactual actions
for a in range(nA):
    weight_a_sa[:, a, a] = 1

# split equally between factual and counterfactual actions
for s in range(nS):
    a = π_star.argmax(axis=1)[s]
    a_tilde = a+1-2*(a%2)  # vaso-flipped counterpart of a
    weight_a_sa[s, a, a] = 0.5
    weight_a_sa[s, a, a_tilde] = 0.5
    weight_a_sa[s, a_tilde, a] = 0.5
    weight_a_sa[s, a_tilde, a_tilde] = 0.5

assert np.all(weight_a_sa.sum(axis=-1) == 1)


# ## Policies

# vaso unif, mv abx optimal
π_unif = (np.tile(π_star.reshape((-1,2,2,2)).sum(axis=3, keepdims=True), (1,1,1,2)).reshape((-1, 8)) / 2)


# ### Behavior policy

# vaso ε-greedy with ε=eps (fixed stale comment: previously said eps=0.5), mv abx optimal
π_beh = (np.tile(π_star.reshape((-1,2,2,2)).sum(axis=3, keepdims=True), (1,1,1,2)).reshape((-1, 8)) / 2)
π_beh[π_star == 1] = 1-eps
π_beh[π_beh == 0.5] = eps

V_H_beh, Q_H_beh, J_beh = policy_eval_helper(π_beh)
J_beh


# ### Optimal policy
V_H_star, Q_H_star, J_star = policy_eval_helper(π_star)
J_star


# ### flip action for x% states

rng_flip = np.random.default_rng(pol_flip_seed)
flip_states = rng_flip.choice(range(1440), pol_flip_num, replace=False)

π_tmp = (np.tile(π_star.reshape((-1,2,2,2)).sum(axis=3, keepdims=True), (1,1,1,2)).reshape((-1, 8)) / 2)
π_flip = π_tmp.copy()
π_flip[π_tmp == 0.5] = 0
π_flip[π_star == 1] = 1
# in the selected states, deterministically take the vaso-flipped action instead
for s in flip_states:
    π_flip[s, π_tmp[s] == 0.5] = 1
    π_flip[s, π_star[s] == 1] = 0
assert π_flip.sum(axis=1).mean() == 1

# np.savetxt(f'./results/{exp_name}/policy_{pol_name}.txt', π_flip)


# ## Compare OPE

π_eval = π_flip


# ### Proposed: replace future with the value
function for the evaluation policy +rng_annot = np.random.default_rng(seed=123456789) +rng_cf = np.random.default_rng(42424242) + + +df_va_all2 = pd.read_pickle(f'results/vaso_eps_{eps_str}-annotOpt_df_seed2_aug_step.pkl') +V_H_eval, Q_H_eval, J_eval = policy_eval_helper(π_eval) + + +df_results_v2 = [] +for run in range(runs): + df_va = df_va_all2.set_index('pt_id').loc[200000+run*run_idx_length:200000+run*run_idx_length + N_val - 1+0.999].reset_index() + + df_va['map_pt_id'] = df_va['pt_id'].apply(np.floor).astype(int) + df_va = df_va.drop_duplicates(['map_pt_id', 'Time', 'State', 'Action']) \ + .sort_values(by=['map_pt_id', 'Time', 'pt_id']).reset_index(drop=True) + df_va['Weight'] = np.nan + + # only use the cf annot with probability pcf + keep_mask = [] + for i, row in tqdm(list(df_va.iterrows()), disable=True): + map_pt_id = row['map_pt_id'] + h = int(row['Time']) + s = int(row['State']) + if row['NextState'] in [1440, 1441]: + keep_mask.append(True) + elif row['NextState'] in [1442]: + if rng_cf.uniform() < annot_ratio: + keep_mask.append(True) + else: + keep_mask.append(False) + else: + keep_mask.append(True) + df_va = df_va.loc[keep_mask] + + for i, row in tqdm(list(df_va.iterrows()), disable=True): + map_pt_id = row['map_pt_id'] + h = int(row['Time']) + s = int(row['State']) + if row['NextState'] in [1440, 1441]: + df_va.loc[i, 'Weight'] = 1.0 + elif row['NextState'] in [1442]: + a_cf = int(row['Action']) + a_f = int(a_cf+1-2*(a_cf%2)) + df_va.loc[i, 'Weight'] = weight_a_sa[s, a_f, a_cf] + df_va.loc[(df_va['map_pt_id'] == map_pt_id) + & (df_va['Time'] == h) + & (df_va['Action'] == a_f), 'Weight'] = weight_a_sa[s, a_f, a_f] + df_va.loc[i, 'Reward'] = Q_H_eval[h][s, a_cf] + rng_annot.normal(0, annot_noise) + else: + pass + df_va['Weight'] = df_va['Weight'].fillna(1) + + # OPE - WIS/WDR prep + v2_data_va = format_data_tensor_cf(df_va) + v2_pi_b_val = compute_augmented_behavior_policy_h(df_va) + + # OPE - WIS + v2_IS_value, v2_WIS_value, v2_ESS_info = 
OPE_PDIS_h(v2_data_va, v2_pi_b_val, π_eval, gamma, epsilon=0.0) + + df_results_v2.append([v2_IS_value, v2_WIS_value, v2_ESS_info['ESS1']]) + pd.DataFrame(df_results_v2, columns=['IS_value', 'WIS_value', 'ESS1']).to_csv(out_fname, index=False) + +df_results_v2 = pd.DataFrame(df_results_v2, columns=['IS_value', 'WIS_value', 'ESS1']) +df_results_v2.to_csv(out_fname, index=False) diff --git a/sepsisSim/experiments/exp-4-annotEvalMissingImpute.py b/sepsisSim/experiments/exp-4-annotEvalMissingImpute.py new file mode 100644 index 0000000..7476e42 --- /dev/null +++ b/sepsisSim/experiments/exp-4-annotEvalMissingImpute.py @@ -0,0 +1,418 @@ +# ## Simulation parameters +exp_name = 'exp-FINAL-4' +eps = 0.10 +eps_str = '0_1' + +run_idx_length = 1_000 +N_val = 1_000 +runs = 50 + +# Number of action-flipped states +import argparse +parser = argparse.ArgumentParser() +parser.add_argument('--flip_num', type=int) +parser.add_argument('--flip_seed', type=int) +parser.add_argument('--annot_noise', type=float, default=0.2) +parser.add_argument('--annot_ratio', type=float, default=1.0) +args = parser.parse_args() +pol_flip_num = args.flip_num +pol_flip_seed = args.flip_seed +annot_noise = args.annot_noise +annot_ratio = args.annot_ratio + +pol_name = f'flip{pol_flip_num}_seed{pol_flip_seed}' +out_fname = f'./results/{exp_name}/vaso_eps_{eps_str}-{pol_name}-aug_step-annotEval-Noise_{annot_noise}-MissingImpute_{annot_ratio}.csv' + +import numpy as np +import pandas as pd + +df_tmp = None +try: + df_tmp = pd.read_csv(out_fname) +except: + pass + +if df_tmp is not None: + print('File exists') + quit() + +from tqdm import tqdm +from collections import defaultdict +import pickle +import itertools +import copy +import random +import itertools +import joblib +from joblib import Parallel, delayed + +from OPE_utils_new import ( + format_data_tensor, + policy_eval_analytic_finite, + OPE_IS_h, + compute_behavior_policy_h, +) + + +def policy_eval_helper(π): + V_H = 
policy_eval_analytic_finite(P.transpose((1,0,2)), R, π, gamma, H) + Q_H = [(R + gamma * P.transpose((1,0,2)) @ V_H[h]) for h in range(1,H)] + [R] + J = isd @ V_H[0] + # Check recursive relationships + assert len(Q_H) == H + assert len(V_H) == H + assert np.all(Q_H[-1] == R) + assert np.all(np.sum(π * Q_H[-1], axis=1) == V_H[-1]) + assert np.all(R + gamma * P.transpose((1,0,2)) @ V_H[-1] == Q_H[-2]) + return V_H, Q_H, J + +def iqm(x): + return scipy.stats.trim_mean(x, proportiontocut=0.25, axis=None) + +NSTEPS = H = 20 # max episode length in historical data # Horizon of the MDP +G_min = -1 # the minimum possible return +G_max = 1 # the maximum possible return +nS, nA = 1442, 8 + +PROB_DIAB = 0.2 + +# Ground truth MDP model +MDP_parameters = joblib.load('../data/MDP_parameters.joblib') +P = MDP_parameters['transition_matrix_absorbing'] # (A, S, S_next) +R = MDP_parameters['reward_matrix_absorbing_SA'] # (S, A) +nS, nA = R.shape +gamma = 0.99 + +# unif rand isd, mixture of diabetic state +isd = joblib.load('../data/modified_prior_initial_state_absorbing.joblib') +isd = (isd > 0).astype(float) +isd[:720] = isd[:720] / isd[:720].sum() * (1-PROB_DIAB) +isd[720:] = isd[720:] / isd[720:].sum() * (PROB_DIAB) + +# Precomputed optimal policy +π_star = joblib.load('../data/π_star.joblib') + + + +# ## Load data +input_dir = f'../datagen/vaso_eps_{eps_str}-100k/' + +def load_data(fname): + print('Loading data', fname, '...', end='') + df_data = pd.read_csv('{}/{}'.format(input_dir, fname)).rename(columns={'State_idx': 'State'})#[['pt_id', 'Time', 'State', 'Action', 'Reward']] + + # Assign next state + df_data['NextState'] = [*df_data['State'].iloc[1:].values, -1] + df_data.loc[df_data.groupby('pt_id')['Time'].idxmax(), 'NextState'] = -1 + df_data.loc[(df_data['Reward'] == -1), 'NextState'] = 1440 + df_data.loc[(df_data['Reward'] == 1), 'NextState'] = 1441 + + assert ((df_data['Reward'] != 0) == (df_data['Action'] == -1)).all() + + print('DONE') + return df_data + + +# df_seed1 
= load_data('1-features.csv') # tr +df_seed2 = load_data('2-features.csv') # va + + +# ## C-PDIS code + +def compute_augmented_behavior_policy_h(df_data): + πh_b = np.zeros((H, nS, nA)) + hsa_counts = df_data.groupby(['Time', 'State', 'Action'])[['Weight']].sum().rename(columns={'Weight': 'count'}).reset_index() + + try: + for i, row in hsa_counts.iterrows(): + h, s, a = int(row['Time']), int(row['State']), int(row['Action']) + count = row['count'] + if row['Action'] == -1: + πh_b[h, s, :] = count + else: + πh_b[h, s, a] = count + except: + print(h,s,a) + raise + # import pdb + # pdb.set_trace() + + # assume uniform action probabilities in unobserved states + unobserved_states = (πh_b.sum(axis=-1) == 0) + πh_b[unobserved_states, :] = 1 + + # normalize action probabilities + πh_b = πh_b / πh_b.sum(axis=-1, keepdims=True) + + return πh_b + + +def format_data_tensor_cf(df_data, id_col='map_pt_id'): + """ + Converts data from a dataframe to a tensor + - df_data: pd.DataFrame with columns [id_col, Time, State, Action, Reward, NextState] + - id_col specifies the index column to group episodes + - data_tensor: integer tensor of shape (N, NSTEPS, 5) with the last last dimension being [t, s, a, r, s'] + """ + data_dict = dict(list(df_data.groupby(id_col))) + N = len(data_dict) + data_tensor = np.zeros((N, 2*NSTEPS, 6), dtype=float) + data_tensor[:, :, 0] = -1 # initialize all time steps to -1 + data_tensor[:, :, 2] = -1 # initialize all actions to -1 + data_tensor[:, :, 1] = -1 # initialize all states to -1 + data_tensor[:, :, 4] = -1 # initialize all next states to -1 + data_tensor[:, :, 5] = np.nan # initialize all weights to NaN + + for i, (pt_id, df_values) in tqdm(enumerate(data_dict.items()), disable=True): + values = df_values.set_index(id_col)[['Time', 'State', 'Action', 'Reward', 'NextState', 'Weight']].values + data_tensor[i, :len(values), :] = values + return data_tensor + + +def OPE_PDIS_h(data, π_b, π_e, γ, epsilon=0.01): + """ + - π_b, π_e: behavior/evaluation 
policy, shape (S,A) + """ + # Get a soft version of the evaluation policy for WIS + π_e_soft = π_e.astype(float) * (1 - epsilon*2) + π_unif * epsilon*2 + + # # Get a soft version of the behavior policy for WIS + # π_b_soft = π_b * (1 - epsilon) + epsilon / nA + + # Apply WIS + return _pdis_h(data, π_b, π_e_soft, γ) + +def _pdis_h(data, π_b, π_e, γ): + # For each original trajectory + v_all, rho_all = [], [] + for i, data_i in enumerate(data): + # Get all trajectories based on this trajectory + t_l = data_i[..., 0].astype(int) + s_l = data_i[..., 1].astype(int) + a_l = data_i[..., 2].astype(int) + r_l = data_i[..., 3].astype(float) + snext_l = data_i[..., 4].astype(int) + w_l = data_i[..., 5].astype(float) + + # Per-transition importance ratios + p_b = π_b[t_l, s_l, a_l] + p_e = π_e[s_l, a_l] + + # Deal with variable length sequences by setting ratio to 1 + terminated_idx = (a_l == -1) + terminating_idx = (s_l != -1) & (a_l == -1) + p_b[terminated_idx] = np.nan + p_e[terminated_idx] = np.nan + p_b[terminating_idx] = 1 + p_e[terminating_idx] = 1 + + # Per-step cumulative importance ratios + rho_t = (p_e / p_b) + + # # Last observed step of each trajectory + # idx_last = np.array([np.max(np.nonzero(s_l[row] != -1)) for row in range(len(s_l))]) + + # Initialize value to 0, importance ratio to 1 + v = 0 + rho_cum = 1 + + # Iterate backwards from step H to 1 + for h in reversed(range(H)): + # only start computing from the last observed step + if not (t_l == h).any(): + continue + + # do we have counterfactual annotation for this step? 
+ if (t_l == h).sum() > 1: + # if we have counterfactual annotations for this step + j_all = np.argwhere(t_l == h).ravel() + assert np.isclose(w_l[j_all].sum(), 1) # weights add up to 1 + + # Identify factual transition and counterfactual annotations + f_, cf_ = [], [] + for j in j_all: + if snext_l[j] == 1442: # counterfactual annotation have dummy next state + cf_.append(j) + else: + f_.append(j) + assert len(f_) == 1 # there should only be one factual transition + f_ = f_[0] + v = w_l[f_]*rho_t[f_]*(r_l[f_]+γ*v) + np.sum([w_l[j]*rho_t[j]*r_l[j] for j in cf_]) + rho_cum = rho_cum * (w_l[f_]*rho_t[f_]) + np.sum([w_l[j]*rho_t[j] for j in cf_]) + else: + # we don't have counterfactual annotations for this step + # there should only be one trajectory and that must be the original traj + j = (t_l == h).argmax() + assert ~np.isnan(p_e[j]) + assert w_l[j] == 1.0 + v = rho_t[j] * (r_l[j]+γ*v) + rho_cum = rho_cum * rho_t[j] + + v_all.append(v) + rho_all.append(rho_cum) + + v_all = np.array(v_all) + rho_all = np.array(rho_all) + is_value = np.nansum(v_all) / len(rho_all) + wis_value = np.nansum(v_all) / np.nansum(rho_all) + rho_norm = rho_all / np.nansum(rho_all) + rho_nna = rho_all[~np.isnan(rho_all)] + rho_norm_nna = rho_norm[~np.isnan(rho_norm)] + ess1 = 1 / np.nansum(rho_norm_nna ** 2) + ess1_ = (np.nansum(rho_nna)) ** 2 / (np.nansum(rho_nna ** 2)) + ess2 = 1. 
/ np.nanmax(rho_norm) + return is_value, wis_value, { + 'ESS1': ess1, 'ESS2': ess2, + 'rho': rho_all, 'rho_norm': rho_norm_nna, + } + + +## Default weighting scheme for C-PDIS + +weight_a_sa = np.zeros((nS, nA, nA)) + +# default weight if no counterfactual actions +for a in range(nA): + weight_a_sa[:, a, a] = 1 + +# split equally between factual and counterfactual actions +for s in range(nS): + a = π_star.argmax(axis=1)[s] + a_tilde = a+1-2*(a%2) + weight_a_sa[s, a, a] = 0.5 + weight_a_sa[s, a, a_tilde] = 0.5 + weight_a_sa[s, a_tilde, a] = 0.5 + weight_a_sa[s, a_tilde, a_tilde] = 0.5 + +assert np.all(weight_a_sa.sum(axis=-1) == 1) + + +# ## Policies + +# vaso unif, mv abx optimal +π_unif = (np.tile(π_star.reshape((-1,2,2,2)).sum(axis=3, keepdims=True), (1,1,1,2)).reshape((-1, 8)) / 2) + + +# ### Behavior policy + +# vaso eps=0.5, mv abx optimal +π_beh = (np.tile(π_star.reshape((-1,2,2,2)).sum(axis=3, keepdims=True), (1,1,1,2)).reshape((-1, 8)) / 2) +π_beh[π_star == 1] = 1-eps +π_beh[π_beh == 0.5] = eps + +V_H_beh, Q_H_beh, J_beh = policy_eval_helper(π_beh) +J_beh + + +# ### Optimal policy +V_H_star, Q_H_star, J_star = policy_eval_helper(π_star) +J_star + + +# ### flip action for x% states + +rng_flip = np.random.default_rng(pol_flip_seed) +flip_states = rng_flip.choice(range(1440), pol_flip_num, replace=False) + +π_tmp = (np.tile(π_star.reshape((-1,2,2,2)).sum(axis=3, keepdims=True), (1,1,1,2)).reshape((-1, 8)) / 2) +π_flip = π_tmp.copy() +π_flip[π_tmp == 0.5] = 0 +π_flip[π_star == 1] = 1 +for s in flip_states: + π_flip[s, π_tmp[s] == 0.5] = 1 + π_flip[s, π_star[s] == 1] = 0 +assert π_flip.sum(axis=1).mean() == 1 + +# np.savetxt(f'./results/{exp_name}/policy_{pol_name}.txt', π_flip) + + +# ## Compare OPE + +π_eval = π_flip + + +# ### Proposed: replace future with the value function for the evaluation policy +rng_annot = np.random.default_rng(seed=123456789) +rng_cf = np.random.default_rng(42424242) + + +df_va_all2 = 
pd.read_pickle(f'results/vaso_eps_{eps_str}-annotOpt_df_seed2_aug_step.pkl') +V_H_eval, Q_H_eval, J_eval = policy_eval_helper(π_eval) + + +df_results_v2 = [] +for run in range(runs): + df_va = df_va_all2.set_index('pt_id').loc[200000+run*run_idx_length:200000+run*run_idx_length + N_val - 1+0.999].reset_index() + + df_va['map_pt_id'] = df_va['pt_id'].apply(np.floor).astype(int) + df_va = df_va.drop_duplicates(['map_pt_id', 'Time', 'State', 'Action']) \ + .sort_values(by=['map_pt_id', 'Time', 'pt_id']).reset_index(drop=True) + df_va['Weight'] = np.nan + + # only draw cf annot with probability pcf + for i, row in tqdm(list(df_va.iterrows()), disable=True): + map_pt_id = row['map_pt_id'] + h = int(row['Time']) + s = int(row['State']) + if row['NextState'] in [1442]: + a_cf = int(row['Action']) + if rng_cf.uniform() < annot_ratio: + df_va.loc[i, 'Reward'] = Q_H_eval[h][s, a_cf] + rng_annot.normal(0, annot_noise) + else: + df_va.loc[i, 'Reward'] = np.nan + + # impute the missing annots + for s, a in itertools.product(range(nS), range(nA)): + sa_cf_mask = (df_va['State'] == s) & (df_va['Action'] == a) & (df_va['NextState'] == 1442) + if not df_va.loc[sa_cf_mask, 'Reward'].notnull().any(): + continue + avg_annot = df_va.loc[sa_cf_mask, 'Reward'].dropna().mean() + df_va.loc[sa_cf_mask, 'Reward'] = df_va.loc[sa_cf_mask, 'Reward'].fillna(avg_annot) + + # only use the cf annot with probability pcf + keep_mask = [] + for i, row in tqdm(list(df_va.iterrows()), disable=True): + map_pt_id = row['map_pt_id'] + h = int(row['Time']) + s = int(row['State']) + if row['NextState'] in [1440, 1441]: + keep_mask.append(True) + elif row['NextState'] in [1442]: + if not np.isnan(row['Reward']): + keep_mask.append(True) + else: + keep_mask.append(False) + else: + keep_mask.append(True) + df_va = df_va.loc[keep_mask] + + # assign weights + for i, row in tqdm(list(df_va.iterrows()), disable=True): + map_pt_id = row['map_pt_id'] + h = int(row['Time']) + s = int(row['State']) + if 
row['NextState'] in [1440, 1441]: + df_va.loc[i, 'Weight'] = 1.0 + elif row['NextState'] in [1442]: + a_cf = int(row['Action']) + a_f = int(a_cf+1-2*(a_cf%2)) + df_va.loc[i, 'Weight'] = weight_a_sa[s, a_f, a_cf] + df_va.loc[(df_va['map_pt_id'] == map_pt_id) + & (df_va['Time'] == h) + & (df_va['Action'] == a_f), 'Weight'] = weight_a_sa[s, a_f, a_f] + else: + pass + df_va['Weight'] = df_va['Weight'].fillna(1.0) + + # OPE - WIS/WDR prep + v2_data_va = format_data_tensor_cf(df_va) + v2_pi_b_val = compute_augmented_behavior_policy_h(df_va) + + # OPE - WIS + v2_IS_value, v2_WIS_value, v2_ESS_info = OPE_PDIS_h(v2_data_va, v2_pi_b_val, π_eval, gamma, epsilon=0.0) + + df_results_v2.append([v2_IS_value, v2_WIS_value, v2_ESS_info['ESS1']]) + pd.DataFrame(df_results_v2, columns=['IS_value', 'WIS_value', 'ESS1']).to_csv(out_fname, index=False) + +df_results_v2 = pd.DataFrame(df_results_v2, columns=['IS_value', 'WIS_value', 'ESS1']) +df_results_v2.to_csv(out_fname, index=False) diff --git a/sepsisSim/experiments/exp-5-annotBehConvertedAM.py b/sepsisSim/experiments/exp-5-annotBehConvertedAM.py new file mode 100644 index 0000000..8a310e8 --- /dev/null +++ b/sepsisSim/experiments/exp-5-annotBehConvertedAM.py @@ -0,0 +1,393 @@ +# ## Simulation parameters +exp_name = 'exp-FINAL-5' +eps = 0.10 +eps_str = '0_1' + +run_idx_length = 1_000 +N_val = 1_000 +runs = 50 + +# Number of action-flipped states +import argparse +parser = argparse.ArgumentParser() +parser.add_argument('--flip_num', type=int) +parser.add_argument('--flip_seed', type=int) +args = parser.parse_args() +pol_flip_num = args.flip_num +pol_flip_seed = args.flip_seed + +pol_name = f'flip{pol_flip_num}_seed{pol_flip_seed}' +out_fname = f'./results/{exp_name}/vaso_eps_{eps_str}-{pol_name}-aug_step-annotBehConvertedAM.csv' + +import numpy as np +import pandas as pd + +df_tmp = None +try: + df_tmp = pd.read_csv(out_fname) +except: + pass + +if df_tmp is not None: + print('File exists') + quit() + +from tqdm import tqdm +from 
collections import defaultdict +import pickle +import itertools +import copy +import random +import itertools +import joblib +from joblib import Parallel, delayed + +from OPE_utils import ( + compute_behavior_policy, + compute_empirical_MDP, +) +from OPE_utils_new import ( + format_data_tensor, + policy_eval_analytic_finite, + OPE_IS_h, + compute_behavior_policy_h, +) + + +NSTEPS = H = 20 # max episode length in historical data # Horizon of the MDP +G_min = -1 # the minimum possible return +G_max = 1 # the maximum possible return +nS, nA = 1442, 8 + +PROB_DIAB = 0.2 + +# Ground truth MDP model +MDP_parameters = joblib.load('../data/MDP_parameters.joblib') +P = MDP_parameters['transition_matrix_absorbing'] # (A, S, S_next) +R = MDP_parameters['reward_matrix_absorbing_SA'] # (S, A) +nS, nA = R.shape +gamma = 0.99 + +# unif rand isd, mixture of diabetic state +isd = joblib.load('../data/modified_prior_initial_state_absorbing.joblib') +isd = (isd > 0).astype(float) +isd[:720] = isd[:720] / isd[:720].sum() * (1-PROB_DIAB) +isd[720:] = isd[720:] / isd[720:].sum() * (PROB_DIAB) + +# Precomputed optimal policy +π_star = joblib.load('../data/π_star.joblib') + + +def policy_eval_helper(π, P=P, R=R, isd=isd): + V_H = policy_eval_analytic_finite(P.transpose((1,0,2)), R, π, gamma, H) + Q_H = [(R + gamma * P.transpose((1,0,2)) @ V_H[h]) for h in range(1,H)] + [R] + J = isd @ V_H[0] + # Check recursive relationships + assert len(Q_H) == H + assert len(V_H) == H + assert np.all(Q_H[-1] == R) + assert np.all(np.sum(π * Q_H[-1], axis=1) == V_H[-1]) + assert np.all(R + gamma * P.transpose((1,0,2)) @ V_H[-1] == Q_H[-2]) + return V_H, Q_H, J + +def iqm(x): + return scipy.stats.trim_mean(x, proportiontocut=0.25, axis=None) + + +# ## Load data +input_dir = f'../datagen/vaso_eps_{eps_str}-100k/' + +def load_data(fname): + print('Loading data', fname, '...', end='') + df_data = pd.read_csv('{}/{}'.format(input_dir, fname)).rename(columns={'State_idx': 'State'})#[['pt_id', 'Time', 'State', 
'Action', 'Reward']] + + # Assign next state + df_data['NextState'] = [*df_data['State'].iloc[1:].values, -1] + df_data.loc[df_data.groupby('pt_id')['Time'].idxmax(), 'NextState'] = -1 + df_data.loc[(df_data['Reward'] == -1), 'NextState'] = 1440 + df_data.loc[(df_data['Reward'] == 1), 'NextState'] = 1441 + + assert ((df_data['Reward'] != 0) == (df_data['Action'] == -1)).all() + + print('DONE') + return df_data + + +# df_seed1 = load_data('1-features.csv') # tr +df_seed2 = load_data('2-features.csv') # va + + +# ## C-PDIS code + +def compute_augmented_behavior_policy_h(df_data): + πh_b = np.zeros((H, nS, nA)) + hsa_counts = df_data.groupby(['Time', 'State', 'Action'])[['Weight']].sum().rename(columns={'Weight': 'count'}).reset_index() + + try: + for i, row in hsa_counts.iterrows(): + h, s, a = int(row['Time']), int(row['State']), int(row['Action']) + count = row['count'] + if row['Action'] == -1: + πh_b[h, s, :] = count + else: + πh_b[h, s, a] = count + except: + print(h,s,a) + raise + # import pdb + # pdb.set_trace() + + # assume uniform action probabilities in unobserved states + unobserved_states = (πh_b.sum(axis=-1) == 0) + πh_b[unobserved_states, :] = 1 + + # normalize action probabilities + πh_b = πh_b / πh_b.sum(axis=-1, keepdims=True) + + return πh_b + + +def format_data_tensor_cf(df_data, id_col='map_pt_id'): + """ + Converts data from a dataframe to a tensor + - df_data: pd.DataFrame with columns [id_col, Time, State, Action, Reward, NextState] + - id_col specifies the index column to group episodes + - data_tensor: integer tensor of shape (N, NSTEPS, 5) with the last last dimension being [t, s, a, r, s'] + """ + data_dict = dict(list(df_data.groupby(id_col))) + N = len(data_dict) + data_tensor = np.zeros((N, 2*NSTEPS, 6), dtype=float) + data_tensor[:, :, 0] = -1 # initialize all time steps to -1 + data_tensor[:, :, 2] = -1 # initialize all actions to -1 + data_tensor[:, :, 1] = -1 # initialize all states to -1 + data_tensor[:, :, 4] = -1 # initialize all 
next states to -1 + data_tensor[:, :, 5] = np.nan # initialize all weights to NaN + + for i, (pt_id, df_values) in tqdm(enumerate(data_dict.items()), disable=True): + values = df_values.set_index(id_col)[['Time', 'State', 'Action', 'Reward', 'NextState', 'Weight']].values + data_tensor[i, :len(values), :] = values + return data_tensor + + +def OPE_PDIS_h(data, π_b, π_e, γ, epsilon=0.01): + """ + - π_b, π_e: behavior/evaluation policy, shape (S,A) + """ + # Get a soft version of the evaluation policy for WIS + π_e_soft = π_e.astype(float) * (1 - epsilon*2) + π_unif * epsilon*2 + + # # Get a soft version of the behavior policy for WIS + # π_b_soft = π_b * (1 - epsilon) + epsilon / nA + + # Apply WIS + return _pdis_h(data, π_b, π_e_soft, γ) + +def _pdis_h(data, π_b, π_e, γ): + # For each original trajectory + v_all, rho_all = [], [] + for i, data_i in enumerate(data): + # Get all trajectories based on this trajectory + t_l = data_i[..., 0].astype(int) + s_l = data_i[..., 1].astype(int) + a_l = data_i[..., 2].astype(int) + r_l = data_i[..., 3].astype(float) + snext_l = data_i[..., 4].astype(int) + w_l = data_i[..., 5].astype(float) + + # Per-transition importance ratios + p_b = π_b[t_l, s_l, a_l] + p_e = π_e[s_l, a_l] + + # Deal with variable length sequences by setting ratio to 1 + terminated_idx = (a_l == -1) + terminating_idx = (s_l != -1) & (a_l == -1) + p_b[terminated_idx] = np.nan + p_e[terminated_idx] = np.nan + p_b[terminating_idx] = 1 + p_e[terminating_idx] = 1 + + # Per-step cumulative importance ratios + rho_t = (p_e / p_b) + + # # Last observed step of each trajectory + # idx_last = np.array([np.max(np.nonzero(s_l[row] != -1)) for row in range(len(s_l))]) + + # Initialize value to 0, importance ratio to 1 + v = 0 + rho_cum = 1 + + # Iterate backwards from step H to 1 + for h in reversed(range(H)): + # only start computing from the last observed step + if not (t_l == h).any(): + continue + + # do we have counterfactual annotation for this step? 
+ if (t_l == h).sum() > 1: + # if we have counterfactual annotations for this step + j_all = np.argwhere(t_l == h).ravel() + assert np.isclose(w_l[j_all].sum(), 1) # weights add up to 1 + + # Identify factual transition and counterfactual annotations + f_, cf_ = [], [] + for j in j_all: + if snext_l[j] == 1442: # counterfactual annotation have dummy next state + cf_.append(j) + else: + f_.append(j) + assert len(f_) == 1 # there should only be one factual transition + f_ = f_[0] + v = w_l[f_]*rho_t[f_]*(r_l[f_]+γ*v) + np.sum([w_l[j]*rho_t[j]*r_l[j] for j in cf_]) + rho_cum = rho_cum * (w_l[f_]*rho_t[f_]) + np.sum([w_l[j]*rho_t[j] for j in cf_]) + else: + # we don't have counterfactual annotations for this step + # there should only be one trajectory and that must be the original traj + j = (t_l == h).argmax() + assert ~np.isnan(p_e[j]) + assert w_l[j] == 1.0 + v = rho_t[j] * (r_l[j]+γ*v) + rho_cum = rho_cum * rho_t[j] + + v_all.append(v) + rho_all.append(rho_cum) + + v_all = np.array(v_all) + rho_all = np.array(rho_all) + is_value = np.nansum(v_all) / len(rho_all) + wis_value = np.nansum(v_all) / np.nansum(rho_all) + rho_norm = rho_all / np.nansum(rho_all) + rho_nna = rho_all[~np.isnan(rho_all)] + rho_norm_nna = rho_norm[~np.isnan(rho_norm)] + ess1 = 1 / np.nansum(rho_norm_nna ** 2) + ess1_ = (np.nansum(rho_nna)) ** 2 / (np.nansum(rho_nna ** 2)) + ess2 = 1. 
/ np.nanmax(rho_norm) + return is_value, wis_value, { + 'ESS1': ess1, 'ESS2': ess2, + 'rho': rho_all, 'rho_norm': rho_norm_nna, + } + + +## Default weighting scheme for C-PDIS + +weight_a_sa = np.zeros((nS, nA, nA)) + +# default weight if no counterfactual actions +for a in range(nA): + weight_a_sa[:, a, a] = 1 + +# split equally between factual and counterfactual actions +for s in range(nS): + a = π_star.argmax(axis=1)[s] + a_tilde = a+1-2*(a%2) + weight_a_sa[s, a, a] = 0.5 + weight_a_sa[s, a, a_tilde] = 0.5 + weight_a_sa[s, a_tilde, a] = 0.5 + weight_a_sa[s, a_tilde, a_tilde] = 0.5 + +assert np.all(weight_a_sa.sum(axis=-1) == 1) + + +# ## Policies + +# vaso unif, mv abx optimal +π_unif = (np.tile(π_star.reshape((-1,2,2,2)).sum(axis=3, keepdims=True), (1,1,1,2)).reshape((-1, 8)) / 2) + + +# ### Behavior policy + +# vaso eps=0.5, mv abx optimal +π_beh = (np.tile(π_star.reshape((-1,2,2,2)).sum(axis=3, keepdims=True), (1,1,1,2)).reshape((-1, 8)) / 2) +π_beh[π_star == 1] = 1-eps +π_beh[π_beh == 0.5] = eps + +V_H_beh, Q_H_beh, J_beh = policy_eval_helper(π_beh) +J_beh + + +# ### Optimal policy +V_H_star, Q_H_star, J_star = policy_eval_helper(π_star) +J_star + + +# ### flip action for x% states + +rng_flip = np.random.default_rng(pol_flip_seed) +flip_states = rng_flip.choice(range(1440), pol_flip_num, replace=False) + +π_tmp = (np.tile(π_star.reshape((-1,2,2,2)).sum(axis=3, keepdims=True), (1,1,1,2)).reshape((-1, 8)) / 2) +π_flip = π_tmp.copy() +π_flip[π_tmp == 0.5] = 0 +π_flip[π_star == 1] = 1 +for s in flip_states: + π_flip[s, π_tmp[s] == 0.5] = 1 + π_flip[s, π_star[s] == 1] = 0 +assert π_flip.sum(axis=1).mean() == 1 + +# np.savetxt(f'./results/{exp_name}/policy_{pol_name}.txt', π_flip) + + +# ## Compare OPE + +π_eval = π_flip + + +# ### Proposed: replace future with the value function for the evaluation policy + +df_va_all2 = pd.read_pickle(f'results/vaso_eps_{eps_str}-annotOpt_df_seed2_aug_step.pkl') +V_H_eval, Q_H_eval, J_eval = policy_eval_helper(π_eval) + + 
+df_results_v2 = [] +for run in range(runs): + + # Original dataset + df_va_orig = df_seed2.set_index('pt_id').loc[200000+run*run_idx_length:200000+run*run_idx_length + N_val - 1].reset_index()[ + ['pt_id', 'Time', 'State', 'Action', 'Reward', 'NextState'] + ] + df_orig = df_va_orig[['pt_id', 'Time', 'State', 'Action', 'Reward', 'NextState']] + data_orig = format_data_tensor(df_orig) + pi_b_orig = compute_behavior_policy(df_orig) + + # Build approx MDP and compute Q-functions for both π_b and π_e + P_approx, R_approx, isd_approx = compute_empirical_MDP(df_va_orig) + V_H_eval_approx, Q_H_eval_approx, J_eval_approx = policy_eval_helper(π_eval, P=P_approx, R=R_approx, isd=isd_approx) + V_H_beh_approx, Q_H_beh_approx, J_beh_approx = policy_eval_helper(pi_b_orig, P=P_approx, R=R_approx, isd=isd_approx) + + # Augmented dataset + df_va = df_va_all2.set_index('pt_id').loc[200000+run*run_idx_length:200000+run*run_idx_length + N_val - 1+0.999].reset_index() + df_va['map_pt_id'] = df_va['pt_id'].apply(np.floor).astype(int) + df_va = df_va.drop_duplicates(['map_pt_id', 'Time', 'State', 'Action']) \ + .sort_values(by=['map_pt_id', 'Time', 'pt_id']).reset_index(drop=True) + df_va['Weight'] = np.nan + + for i, row in tqdm(list(df_va.iterrows()), disable=True): + map_pt_id = row['map_pt_id'] + h = int(row['Time']) + s = int(row['State']) + if row['NextState'] in [1440, 1441]: + df_va.loc[i, 'Weight'] = 1.0 + elif row['NextState'] in [1442]: + a_cf = int(row['Action']) + a_f = int(a_cf+1-2*(a_cf%2)) + df_va.loc[i, 'Weight'] = weight_a_sa[s, a_f, a_cf] + df_va.loc[(df_va['map_pt_id'] == map_pt_id) + & (df_va['Time'] == h) + & (df_va['Action'] == a_f), 'Weight'] = weight_a_sa[s, a_f, a_f] + annot_cf = Q_H_beh[h][s,a_cf] - Q_H_beh_approx[h][s,a_cf] + Q_H_eval_approx[h][s,a_cf] + df_va.loc[i, 'Reward'] = annot_cf + else: + pass + df_va['Weight'] = df_va['Weight'].fillna(1) + + # OPE - WIS/WDR prep + v2_data_va = format_data_tensor_cf(df_va) + v2_pi_b_val = 
compute_augmented_behavior_policy_h(df_va) + + # OPE - WIS + v2_IS_value, v2_WIS_value, v2_ESS_info = OPE_PDIS_h(v2_data_va, v2_pi_b_val, π_eval, gamma, epsilon=0.0) + + df_results_v2.append([v2_IS_value, v2_WIS_value, v2_ESS_info['ESS1']]) + pd.DataFrame(df_results_v2, columns=['IS_value', 'WIS_value', 'ESS1']).to_csv(out_fname, index=False) + +df_results_v2 = pd.DataFrame(df_results_v2, columns=['IS_value', 'WIS_value', 'ESS1']) +df_results_v2.to_csv(out_fname, index=False) diff --git a/sepsisSim/experiments/fig/annot_KL.pdf b/sepsisSim/experiments/fig/annot_KL.pdf new file mode 100644 index 0000000..542fc38 Binary files /dev/null and b/sepsisSim/experiments/fig/annot_KL.pdf differ diff --git a/sepsisSim/experiments/fig/annot_KL_0.pdf b/sepsisSim/experiments/fig/annot_KL_0.pdf new file mode 100644 index 0000000..2a7fffd Binary files /dev/null and b/sepsisSim/experiments/fig/annot_KL_0.pdf differ diff --git a/sepsisSim/experiments/fig/annotation_legend.pdf b/sepsisSim/experiments/fig/annotation_legend.pdf new file mode 100644 index 0000000..7542831 Binary files /dev/null and b/sepsisSim/experiments/fig/annotation_legend.pdf differ diff --git a/sepsisSim/experiments/fig/annotation_missing.pdf b/sepsisSim/experiments/fig/annotation_missing.pdf new file mode 100644 index 0000000..491a8a2 Binary files /dev/null and b/sepsisSim/experiments/fig/annotation_missing.pdf differ diff --git a/sepsisSim/experiments/fig/annotation_noisy.pdf b/sepsisSim/experiments/fig/annotation_noisy.pdf new file mode 100644 index 0000000..c5b2f4c Binary files /dev/null and b/sepsisSim/experiments/fig/annotation_noisy.pdf differ diff --git a/sepsisSim/experiments/fig/annotation_noisy_missing10.pdf b/sepsisSim/experiments/fig/annotation_noisy_missing10.pdf new file mode 100644 index 0000000..53e2054 Binary files /dev/null and b/sepsisSim/experiments/fig/annotation_noisy_missing10.pdf differ diff --git a/sepsisSim/experiments/fig/sepsisSim-policies.pdf 
b/sepsisSim/experiments/fig/sepsisSim-policies.pdf new file mode 100644 index 0000000..4177c16 Binary files /dev/null and b/sepsisSim/experiments/fig/sepsisSim-policies.pdf differ diff --git a/sepsisSim/experiments/plots--legend.ipynb b/sepsisSim/experiments/plots--legend.ipynb new file mode 100644 index 0000000..5a993f2 --- /dev/null +++ b/sepsisSim/experiments/plots--legend.ipynb @@ -0,0 +1,879 @@ +{ + "cells": [ + { + "cell_type": "code", + "execution_count": 1, + "id": "61afdeb1-b639-4f9f-a936-59d50ac084a9", + "metadata": {}, + "outputs": [], + "source": [ + "%config InlineBackend.figure_formats = ['svg']\n", + "import matplotlib.pyplot as plt\n", + "import matplotlib\n", + "matplotlib.rcParams['font.sans-serif'] = ['Arial']" + ] + }, + { + "cell_type": "code", + "execution_count": 2, + "id": "dc8cb65e-1d87-4857-82a3-489da0c05e5a", + "metadata": {}, + "outputs": [ + { + "data": { + "image/svg+xml": [ + "\n", + "\n", + "\n", + "\n", + " \n", + " \n", + " \n", + " \n", + " 2023-05-14T13:07:04.330084\n", + " image/svg+xml\n", + " \n", + " \n", + " Matplotlib v3.3.2, https://matplotlib.org/\n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " 
\n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + "\n" + ], + "text/plain": [ + "
" + ] + }, + "metadata": { + "needs_background": "light" + }, + "output_type": "display_data" + } + ], + "source": [ + "fig, ax = plt.subplots(figsize=(0.1,0.1))\n", + "plt.legend(\n", + " [matplotlib.lines.Line2D([], [], c='tab:blue'), \n", + " matplotlib.lines.Line2D([], [], c='w'),\n", + " matplotlib.lines.Line2D([], [], c='gray', ls=(0,(0.5,0.5))), \n", + " matplotlib.lines.Line2D([], [], c='k', marker='*', linestyle='none'),\n", + " matplotlib.lines.Line2D([], [], c='r', ls=(0,(1,1))), \n", + " matplotlib.lines.Line2D([], [], c='g', ls='--'), \n", + " ],\n", + " ['C*-PDIS, noisy annot.', \n", + " '',\n", + " 'PDIS (baseline)', \n", + " 'ideal case', \n", + " 'C-PDIS , w/o imputation', \n", + " 'C*-PDIS, w/ imputation'],\n", + " loc='center', bbox_to_anchor=(0.5,1.0),\n", + " ncol=3, handlelength=1.333, handletextpad=0.6, columnspacing=1,\n", + ")\n", + "ax.axis(False)\n", + "plt.savefig('fig/annotation_legend.pdf', bbox_inches='tight')\n", + "plt.show()" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "f76b409d-35e6-4ca6-9dcf-58771f06f1b9", + "metadata": {}, + "outputs": [], + "source": [] + } + ], + "metadata": { + "kernelspec": { + "display_name": "Python 3 (ipykernel)", + "language": "python", + "name": "python3" + }, + "language_info": { + "codemirror_mode": { + "name": "ipython", + "version": 3 + }, + "file_extension": ".py", + "mimetype": "text/x-python", + "name": "python", + "nbconvert_exporter": "python", + "pygments_lexer": "ipython3", + "version": "3.8.5" + } + }, + "nbformat": 4, + "nbformat_minor": 5 +} diff --git a/sepsisSim/experiments/plots-analyses.ipynb b/sepsisSim/experiments/plots-analyses.ipynb new file mode 100644 index 0000000..73e29c0 --- /dev/null +++ b/sepsisSim/experiments/plots-analyses.ipynb @@ -0,0 +1,5311 @@ +{ + "cells": [ + { + "cell_type": "code", + "execution_count": 1, + "id": "89b14dad-9e23-4f54-8bdf-dd5eebd8c025", + "metadata": {}, + "outputs": [], + "source": [ + "# ## Simulation parameters\n", + 
"exp_name = 'exp-FINAL'\n", + "eps = 0.10\n", + "eps_str = '0_1'\n", + "\n", + "run_idx_length = 1_000\n", + "N_val = 1_000\n", + "runs = 50" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "d811a8da-cd5a-4aba-a268-9b01d5c13442", + "metadata": {}, + "outputs": [], + "source": [ + "import matplotlib\n", + "matplotlib.rcParams['font.sans-serif'] = ['Arial']\n", + "%config InlineBackend.figure_formats = ['svg']" + ] + }, + { + "cell_type": "code", + "execution_count": 3, + "id": "11339963-17b7-44f7-9823-59ac511a4a8a", + "metadata": {}, + "outputs": [], + "source": [ + "import numpy as np\n", + "import pandas as pd\n", + "from tqdm import tqdm\n", + "from collections import defaultdict\n", + "import pickle\n", + "import itertools\n", + "import copy\n", + "import random\n", + "import matplotlib.pyplot as plt\n", + "import seaborn as sns\n", + "import scipy.stats\n", + "from sklearn import metrics\n", + "import itertools\n", + "\n", + "import joblib\n", + "from joblib import Parallel, delayed" + ] + }, + { + "cell_type": "code", + "execution_count": 35, + "id": "a510145e-f9df-4460-a15d-377f2b4a6072", + "metadata": {}, + "outputs": [], + "source": [ + "plt.rcParams.update({\n", + " \"text.usetex\": True,\n", + " \"font.family\": \"sans-serif\",\n", + " \"font.sans-serif\": [\"Helvetica\"],\n", + "})" + ] + }, + { + "cell_type": "code", + "execution_count": 5, + "id": "0d61671d-b32e-4ec5-8859-85a5fae5c3c4", + "metadata": {}, + "outputs": [], + "source": [ + "from OPE_utils_new import (\n", + " format_data_tensor,\n", + " policy_eval_analytic_finite,\n", + " OPE_IS_h,\n", + " compute_behavior_policy_h,\n", + ")" + ] + }, + { + "cell_type": "code", + "execution_count": 6, + "id": "0c573be0-9afb-4dc0-b018-4c62a0ce15f7", + "metadata": {}, + "outputs": [], + "source": [ + "def policy_eval_helper(π):\n", + " V_H = policy_eval_analytic_finite(P.transpose((1,0,2)), R, π, gamma, H)\n", + " Q_H = [(R + gamma * P.transpose((1,0,2)) @ V_H[h]) for h in 
range(1,H)] + [R]\n", + " J = isd @ V_H[0]\n", + " # Check recursive relationships\n", + " assert len(Q_H) == H\n", + " assert len(V_H) == H\n", + " assert np.all(Q_H[-1] == R)\n", + " assert np.all(np.sum(π * Q_H[-1], axis=1) == V_H[-1])\n", + " assert np.all(R + gamma * P.transpose((1,0,2)) @ V_H[-1] == Q_H[-2])\n", + " return V_H, Q_H, J" + ] + }, + { + "cell_type": "code", + "execution_count": 7, + "id": "a0ab1c7c-bcd3-43ad-848b-93268fa1a2b6", + "metadata": {}, + "outputs": [], + "source": [ + "def iqm(x):\n", + " return scipy.stats.trim_mean(x, proportiontocut=0.25, axis=None)" + ] + }, + { + "cell_type": "code", + "execution_count": 8, + "id": "467e2e02-2f18-4fa9-9d17-3b9ae0e4b8dd", + "metadata": {}, + "outputs": [], + "source": [ + "NSTEPS = H = 20 # max episode length in historical data # Horizon of the MDP\n", + "G_min = -1 # the minimum possible return\n", + "G_max = 1 # the maximum possible return\n", + "nS, nA = 1442, 8\n", + "\n", + "PROB_DIAB = 0.2" + ] + }, + { + "cell_type": "code", + "execution_count": 9, + "id": "fdf1dfc0-7d33-4f7d-952c-5b77668f2968", + "metadata": {}, + "outputs": [], + "source": [ + "# Ground truth MDP model\n", + "MDP_parameters = joblib.load('../data/MDP_parameters.joblib')\n", + "P = MDP_parameters['transition_matrix_absorbing'] # (A, S, S_next)\n", + "R = MDP_parameters['reward_matrix_absorbing_SA'] # (S, A)\n", + "nS, nA = R.shape\n", + "gamma = 0.99\n", + "\n", + "# unif rand isd, mixture of diabetic state\n", + "isd = joblib.load('../data/modified_prior_initial_state_absorbing.joblib')\n", + "isd = (isd > 0).astype(float)\n", + "isd[:720] = isd[:720] / isd[:720].sum() * (1-PROB_DIAB)\n", + "isd[720:] = isd[720:] / isd[720:].sum() * (PROB_DIAB)" + ] + }, + { + "cell_type": "code", + "execution_count": 10, + "id": "c210ea9b-a6e0-42e6-be21-35a8e439c0c8", + "metadata": {}, + "outputs": [], + "source": [ + "# Precomputed optimal policy\n", + "π_star = joblib.load('../data/π_star.joblib')" + ] + }, + { + "attachments": {}, + 
"cell_type": "markdown", + "id": "9342ffbd-00f1-4916-b418-054a289f20bd", + "metadata": { + "tags": [] + }, + "source": [ + "## Policies" + ] + }, + { + "cell_type": "code", + "execution_count": 11, + "id": "645c4315-1109-4b41-b7ec-177e03e3343d", + "metadata": {}, + "outputs": [], + "source": [ + "# vaso unif, mv abx optimal\n", + "π_unif = (np.tile(π_star.reshape((-1,2,2,2)).sum(axis=3, keepdims=True), (1,1,1,2)).reshape((-1, 8)) / 2)" + ] + }, + { + "attachments": {}, + "cell_type": "markdown", + "id": "32ece6ab-6dca-479a-ab85-0cc3233e9f72", + "metadata": {}, + "source": [ + "### Behavior policy" + ] + }, + { + "cell_type": "code", + "execution_count": 12, + "id": "639c3b25-c715-4a90-8d1a-094e4478b9c9", + "metadata": {}, + "outputs": [], + "source": [ + "# vaso eps=0.5, mv abx optimal\n", + "π_beh = (np.tile(π_star.reshape((-1,2,2,2)).sum(axis=3, keepdims=True), (1,1,1,2)).reshape((-1, 8)) / 2)\n", + "π_beh[π_star == 1] = 1-eps\n", + "π_beh[π_beh == 0.5] = eps" + ] + }, + { + "cell_type": "code", + "execution_count": 13, + "id": "50a792f2-1f9d-40f7-89db-bd9245eabdf3", + "metadata": {}, + "outputs": [ + { + "data": { + "text/plain": [ + "0.25038354793851164" + ] + }, + "execution_count": 13, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "V_H_beh, Q_H_beh, J_beh = policy_eval_helper(π_beh)\n", + "J_beh" + ] + }, + { + "attachments": {}, + "cell_type": "markdown", + "id": "8d88637a-8298-4f1f-88ef-ab11b05ec08b", + "metadata": {}, + "source": [ + "### Optimal policy" + ] + }, + { + "cell_type": "code", + "execution_count": 14, + "id": "d849f76e-01a1-48f8-9d8a-e9d0a880cadf", + "metadata": {}, + "outputs": [ + { + "data": { + "text/plain": [ + "0.40877179296760224" + ] + }, + "execution_count": 14, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "V_H_star, Q_H_star, J_star = policy_eval_helper(π_star)\n", + "J_star" + ] + }, + { + "attachments": {}, + "cell_type": "markdown", + "id": 
"44136cb8-d2e0-459d-8ee1-f15b2ed8a62c", + "metadata": {}, + "source": [ + "### flip action for x% states" + ] + }, + { + "cell_type": "code", + "execution_count": 15, + "id": "97e67e52-09c6-4f0d-b435-1304a88a98d6", + "metadata": {}, + "outputs": [], + "source": [ + "def get_π_flip(pol_flip_seed, pol_flip_num):\n", + " rng_flip = np.random.default_rng(pol_flip_seed)\n", + " flip_states = rng_flip.choice(range(1440), pol_flip_num, replace=False)\n", + "\n", + " π_tmp = (np.tile(π_star.reshape((-1,2,2,2)).sum(axis=3, keepdims=True), (1,1,1,2)).reshape((-1, 8)) / 2)\n", + " π_flip = π_tmp.copy()\n", + " π_flip[π_tmp == 0.5] = 0\n", + " π_flip[π_star == 1] = 1\n", + " for s in flip_states:\n", + " π_flip[s, π_tmp[s] == 0.5] = 1\n", + " π_flip[s, π_star[s] == 1] = 0\n", + " assert π_flip.sum(axis=1).mean() == 1\n", + " return π_flip" + ] + }, + { + "cell_type": "code", + "execution_count": 16, + "id": "67c8441f-07bc-425a-ad5b-26b970653986", + "metadata": {}, + "outputs": [], + "source": [ + "πs_flip_list = []\n", + "for flip_num in [0, 50, 100, 200, 300, 400]:\n", + " for flip_seed in [0, 42, 123, 424242, 10000]:\n", + " if flip_num == 0 and flip_seed != 0: continue\n", + " π_flip = get_π_flip(flip_seed, flip_num)\n", + " πs_flip_list.append(π_flip)" + ] + }, + { + "cell_type": "code", + "execution_count": 17, + "id": "976a7213-6cb1-4511-a888-80173c967581", + "metadata": {}, + "outputs": [ + { + "name": "stderr", + "output_type": "stream", + "text": [ + "100%|███████████████████████████████████████████████████████████████████████████████████| 26/26 [00:20<00:00, 1.24it/s]\n" + ] + } + ], + "source": [ + "v_list = []\n", + "for π_eval in tqdm(πs_flip_list):\n", + " _, _, J_eval = policy_eval_helper(π_eval)\n", + " v_list.append(J_eval)" + ] + }, + { + "cell_type": "code", + "execution_count": 18, + "id": "c421cb40-2069-4d04-be56-2456dc6810b4", + "metadata": {}, + "outputs": [ + { + "data": { + "image/svg+xml": [ + "\n", + "\n", + "\n", + "\n", + " \n", + " \n", + " \n", + 
" \n", + " 2023-05-08T12:55:34.507399\n", + " image/svg+xml\n", + " \n", + " \n", + " Matplotlib v3.3.2, https://matplotlib.org/\n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " 
\n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + "\n" + ], + "text/plain": [ + "
" + ] + }, + "metadata": { + "needs_background": "light" + }, + "output_type": "display_data" + } + ], + "source": [ + "plt.plot(v_list, ls='none', marker='.')\n", + "plt.axhline(J_beh, c='k', ls='--', label='behavior', zorder=0)\n", + "plt.xlabel('policy index')\n", + "plt.ylabel('policy value')\n", + "plt.show()" + ] + }, + { + "attachments": {}, + "cell_type": "markdown", + "id": "62b4937e-484a-4740-90ea-8cf91fd19578", + "metadata": {}, + "source": [ + "## Load results" + ] + }, + { + "cell_type": "code", + "execution_count": 19, + "id": "73d56e0a-de4a-4a90-9c3a-c4b78a02e64b", + "metadata": {}, + "outputs": [], + "source": [ + "df_results_0 = pd.read_csv(f'./results/{exp_name}-1/vaso_eps_{eps_str}-observed.csv')" + ] + }, + { + "cell_type": "code", + "execution_count": 20, + "id": "d23aed3c-898f-4c44-b32b-20f744ab3a7d", + "metadata": {}, + "outputs": [], + "source": [ + "dfs_results_orig_list = []\n", + "for flip_num in [0, 50, 100, 200, 300, 400]:\n", + " for flip_seed in [0, 42, 123, 424242, 10000]:\n", + " if flip_num == 0 and flip_seed != 0: continue\n", + " pol_name = f'flip{flip_num}_seed{flip_seed}'\n", + " df_results_ = pd.read_csv(f'./results/{exp_name}-1/vaso_eps_{eps_str}-{pol_name}-orig.csv')\n", + " dfs_results_orig_list.append(df_results_)" + ] + }, + { + "cell_type": "code", + "execution_count": 21, + "id": "074e313e-ccc4-475f-857b-12b83da8b48d", + "metadata": {}, + "outputs": [], + "source": [ + "dfs_results_annotEval_list = []\n", + "for flip_num in [0, 50, 100, 200, 300, 400]:\n", + " for flip_seed in [0, 42, 123, 424242, 10000]:\n", + " if flip_num == 0 and flip_seed != 0: continue\n", + " pol_name = f'flip{flip_num}_seed{flip_seed}'\n", + " df_results_ = pd.read_csv(f'./results/{exp_name}-2/vaso_eps_{eps_str}-{pol_name}-aug_step-annotEval.csv')\n", + " dfs_results_annotEval_list.append(df_results_)" + ] + }, + { + "cell_type": "code", + "execution_count": 22, + "id": "6cb0164e-1c33-486c-ad59-ad1ebf9a2e01", + "metadata": {}, + "outputs": [], + 
"source": [ + "dfs_results_annotBeh_list = []\n", + "for flip_num in [0, 50, 100, 200, 300, 400]:\n", + " for flip_seed in [0, 42, 123, 424242, 10000]:\n", + " if flip_num == 0 and flip_seed != 0: continue\n", + " pol_name = f'flip{flip_num}_seed{flip_seed}'\n", + " df_results_ = pd.read_csv(f'./results/{exp_name}-2/vaso_eps_{eps_str}-{pol_name}-aug_step-annotBeh.csv')\n", + " dfs_results_annotBeh_list.append(df_results_)" + ] + }, + { + "cell_type": "code", + "execution_count": 23, + "id": "1266a269-b485-4b2d-bd1b-c243c84da908", + "metadata": {}, + "outputs": [], + "source": [ + "dfs_results_annotBehConvertedAM_list = []\n", + "for flip_num in [0, 50, 100, 200, 300, 400]:\n", + " for flip_seed in [0, 42, 123, 424242, 10000]:\n", + " if flip_num == 0 and flip_seed != 0: continue\n", + " pol_name = f'flip{flip_num}_seed{flip_seed}'\n", + " df_results_ = pd.read_csv(f'./results/{exp_name}-2/vaso_eps_{eps_str}-{pol_name}-aug_step-annotBehConvertedAM.csv')\n", + " dfs_results_annotBehConvertedAM_list.append(df_results_)" + ] + }, + { + "attachments": {}, + "cell_type": "markdown", + "id": "df0a151c-99ca-4187-bb6b-cf4952c9724d", + "metadata": {}, + "source": [ + "### Expected state distribution under π_b" + ] + }, + { + "cell_type": "code", + "execution_count": 24, + "id": "ab2c209b-86f5-4332-85b4-5b6570e93b0c", + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Loading data 2-features.csv ...DONE\n" + ] + } + ], + "source": [ + "input_dir = f'../datagen/vaso_eps_{eps_str}-100k/'\n", + "def load_data(fname):\n", + " print('Loading data', fname, '...', end='')\n", + " df_data = pd.read_csv('{}/{}'.format(input_dir, fname)).rename(columns={'State_idx': 'State'})#[['pt_id', 'Time', 'State', 'Action', 'Reward']]\n", + "\n", + " # Assign next state\n", + " df_data['NextState'] = [*df_data['State'].iloc[1:].values, -1]\n", + " df_data.loc[df_data.groupby('pt_id')['Time'].idxmax(), 'NextState'] = -1\n", + " 
df_data.loc[(df_data['Reward'] == -1), 'NextState'] = 1440\n", + " df_data.loc[(df_data['Reward'] == 1), 'NextState'] = 1441\n", + "\n", + " assert ((df_data['Reward'] != 0) == (df_data['Action'] == -1)).all()\n", + "\n", + " print('DONE')\n", + " return df_data\n", + "\n", + "df_seed2 = load_data('2-features.csv') # va" + ] + }, + { + "cell_type": "code", + "execution_count": 25, + "id": "36082d2c-b862-4930-b676-393781593826", + "metadata": {}, + "outputs": [], + "source": [ + "state_dist = df_seed2['State'].value_counts().reindex(range(nS), fill_value=0)" + ] + }, + { + "attachments": {}, + "cell_type": "markdown", + "id": "1f8bff6a-1a52-4b51-8d6e-89c8c6ffd5c6", + "metadata": {}, + "source": [ + "## Plots" + ] + }, + { + "cell_type": "code", + "execution_count": 26, + "id": "dd9e06d4-5a5c-415d-b3bb-e743628d7bc7", + "metadata": {}, + "outputs": [], + "source": [ + "def rmse(y1, y2):\n", + " return np.sqrt(np.mean(np.square(y1-y2)))" + ] + }, + { + "cell_type": "code", + "execution_count": 27, + "id": "6d74e0f2-39d7-40ee-8541-10544a05462e", + "metadata": {}, + "outputs": [], + "source": [ + "def kl_divergence(p, q):\n", + " p_pnz, q_pnz = p[p != 0], q[p != 0]\n", + " return np.sum(p_pnz * np.log(p_pnz) - np.log(q_pnz), 0)" + ] + }, + { + "cell_type": "code", + "execution_count": 52, + "id": "d7b11f55-2ac6-423f-ab43-e4685fbefc9c", + "metadata": {}, + "outputs": [], + "source": [ + "rmse_dict = {}\n", + "stderr_dict = {}\n", + "for name, dfs_results_list in {\n", + " 'orig': dfs_results_orig_list,\n", + " 'annotEval': dfs_results_annotEval_list, \n", + " 'annotBeh': dfs_results_annotBeh_list,\n", + " 'annotBehConvertedAM': dfs_results_annotBehConvertedAM_list,\n", + "}.items():\n", + " rmse_list = []\n", + " stderr_list = []\n", + " for idx, π_idx in enumerate(πs_flip_list):\n", + " J_eval = v_list[idx]\n", + " df_results = dfs_results_list[idx]\n", + " v_IS, v_WIS = list(df_results['IS_value']), list(df_results['WIS_value'])\n", + " rmse_value = rmse(v_IS + v_WIS, 
J_eval)\n", + " rmse_list.append(rmse_value)\n", + " stderr_value = np.std(v_IS + v_WIS - J_eval)\n", + " stderr_list.append(stderr_value)\n", + " \n", + " rmse_dict[name] = rmse_list\n", + " stderr_dict[name] = stderr_list" + ] + }, + { + "cell_type": "code", + "execution_count": 29, + "id": "62acebac-3a89-4075-81ed-e54a037cd6bc", + "metadata": {}, + "outputs": [], + "source": [ + "kl_list = []\n", + "for idx, π_idx in enumerate(πs_flip_list):\n", + " kl_ = [kl_divergence(πe_s, πb_s) for πb_s, πe_s in zip(π_beh, π_idx)]\n", + " kl_list.append(np.average(kl_, weights=state_dist))" + ] + }, + { + "cell_type": "code", + "execution_count": 48, + "id": "e93ffa04-20df-41c9-aef1-fd95326abff6", + "metadata": {}, + "outputs": [ + { + "data": { + "image/svg+xml": [ + "\n", + "\n", + "\n", + "\n", + " \n", + " \n", + " \n", + " \n", + " 2023-05-08T16:53:12.324033\n", + " image/svg+xml\n", + " \n", + " \n", + " Matplotlib v3.3.2, https://matplotlib.org/\n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " 
\n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + "\n" + ], + "text/plain": [ + "
" + ] + }, + "metadata": { + "needs_background": "light" + }, + "output_type": "display_data" + } + ], + "source": [ + "fig, ax = plt.subplots(figsize=(3,2))\n", + "plt.plot(kl_list, v_list, ls='none', marker='D', c='k', markersize=3, alpha=0.8)\n", + "plt.axhline(J_beh, c='gray', ls='--', label='behavior', zorder=0)\n", + "plt.xlabel('$D_{\\mathrm{KL}}(\\pi_b || \\pi_e)$')\n", + "plt.ylabel('$v(\\pi_e)$')\n", + "plt.show()" + ] + }, + { + "cell_type": "code", + "execution_count": 58, + "id": "d2ca8213-5ced-43de-8f3f-4ab477e0d976", + "metadata": {}, + "outputs": [ + { + "data": { + "text/html": [ + "
\n", + "\n", + "\n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + "
01
00.0130150.012592
10.0130030.012563
20.0136300.013513
30.0135150.013259
40.0138440.013433
50.0134780.013025
60.0135050.013314
70.0104730.010330
80.0122400.012103
90.0115050.011412
100.0131750.013014
110.0140950.013603
120.0169480.016692
130.0144190.014211
140.0122070.012153
150.0147580.014442
160.0151810.015076
170.0141900.013969
180.0137080.013577
190.0151500.014883
200.0168940.016700
210.0124080.012404
220.0113220.011313
230.0118160.011751
240.0149730.014799
250.0127470.012623
\n", + "
" + ], + "text/plain": [ + " 0 1\n", + "0 0.013015 0.012592\n", + "1 0.013003 0.012563\n", + "2 0.013630 0.013513\n", + "3 0.013515 0.013259\n", + "4 0.013844 0.013433\n", + "5 0.013478 0.013025\n", + "6 0.013505 0.013314\n", + "7 0.010473 0.010330\n", + "8 0.012240 0.012103\n", + "9 0.011505 0.011412\n", + "10 0.013175 0.013014\n", + "11 0.014095 0.013603\n", + "12 0.016948 0.016692\n", + "13 0.014419 0.014211\n", + "14 0.012207 0.012153\n", + "15 0.014758 0.014442\n", + "16 0.015181 0.015076\n", + "17 0.014190 0.013969\n", + "18 0.013708 0.013577\n", + "19 0.015150 0.014883\n", + "20 0.016894 0.016700\n", + "21 0.012408 0.012404\n", + "22 0.011322 0.011313\n", + "23 0.011816 0.011751\n", + "24 0.014973 0.014799\n", + "25 0.012747 0.012623" + ] + }, + "execution_count": 58, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "pd.DataFrame([rmse_dict['annotEval'], stderr_dict['annotEval']]).T" + ] + }, + { + "cell_type": "code", + "execution_count": 84, + "id": "2a305d54-edc6-4d02-afef-6940b2c0f064", + "metadata": {}, + "outputs": [ + { + "data": { + "image/svg+xml": [ + "\n", + "\n", + "\n", + "\n", + " \n", + " \n", + " \n", + " \n", + " 2023-05-08T17:14:46.991419\n", + " image/svg+xml\n", + " \n", + " \n", + " Matplotlib v3.3.2, https://matplotlib.org/\n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + 
" \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " 
\n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " 
\n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + "\n" + ], + "text/plain": [ + "
" + ] + }, + "metadata": { + "needs_background": "light" + }, + "output_type": "display_data" + } + ], + "source": [ + "fig, ax = plt.subplots(figsize=(3,2.5))\n", + "x = kl_list\n", + "y1 = rmse_dict['orig']\n", + "y2 = rmse_dict['annotEval']\n", + "y3 = rmse_dict['annotBeh']\n", + "y4 = rmse_dict['annotBehConvertedAM']\n", + "\n", + "w1, cov1 = np.polyfit(x, y1, 1, cov=True)\n", + "e1 = np.sqrt(np.diag(cov1))\n", + "w2, cov2 = np.polyfit(x, y2, 1, cov=True)\n", + "e2 = np.sqrt(np.diag(cov2))\n", + "w3, cov3 = np.polyfit(x, y3, 1, cov=True)\n", + "e3 = np.sqrt(np.diag(cov3))\n", + "w4, cov4 = np.polyfit(x, y4, 1, cov=True)\n", + "e4 = np.sqrt(np.diag(cov4))\n", + "\n", + "x0 = np.arange(0,1.1,0.1)\n", + "plt.plot(x, y1, ls='none', marker='.', mew=0, c='gray', alpha=0.8)\n", + "plt.plot(x, y3, ls='none', marker='.', mew=0, c='violet', alpha=0.8)\n", + "plt.plot(x, y4, ls='none', marker='.', mew=0, c='tab:pink', alpha=0.8)\n", + "plt.plot(x, y2, ls='none', marker='.', mew=0, c='tab:purple', alpha=0.8)\n", + "plt.plot(x0, np.poly1d(w1)(x0), ls='-', lw=2, c='gray', label='PDIS (baseline)')\n", + "plt.plot(x0, np.poly1d(w3)(x0), ls='-', lw=2, c='violet', label='C*-PDIS ($G = Q^{\\pi_b}$)')\n", + "plt.plot(x0, np.poly1d(w4)(x0), ls='-', lw=2, c='tab:pink', label='C*-PDIS ($G = Q^{\\pi_b} \\mapsto \\hat{Q}^{\\pi_e}$)')\n", + "plt.plot(x0, np.poly1d(w2)(x0), ls='-', lw=2, c='tab:purple', label='C*-PDIS ($G = Q^{\\pi_e}$)')\n", + "plt.fill_between(x0, np.poly1d(w1-e1)(x0), np.poly1d(w1+e1)(x0), fc='gray', alpha=0.2)\n", + "plt.fill_between(x0, np.poly1d(w2-e2)(x0), np.poly1d(w2+e2)(x0), fc='violet', alpha=0.2)\n", + "plt.fill_between(x0, np.poly1d(w3-e3)(x0), np.poly1d(w3+e3)(x0), fc='tab:pink', alpha=0.2)\n", + "plt.fill_between(x0, np.poly1d(w4-e4)(x0), np.poly1d(w4+e4)(x0), fc='tab:purple', alpha=0.2)\n", + "plt.legend(loc='lower center', bbox_to_anchor=(0.5,1.0), labelspacing=0.6)\n", + "plt.xticks([0,0.5,1])\n", + "plt.xlabel('$D_{\\mathrm{KL}}(\\pi_b || 
\\pi_e)$')\n", + "plt.ylabel('RMSE($v(\\pi_e),\\hat{v}(\\pi_e)$)')\n", + "plt.savefig('fig/annot_KL.pdf', bbox_inches='tight')\n", + "plt.yscale('log')\n", + "plt.savefig('fig/annot_KL_logscale.pdf', bbox_inches='tight')\n", + "plt.show()" + ] + }, + { + "cell_type": "markdown", + "id": "fcad73db-079b-415b-9360-6011ff570d55", + "metadata": {}, + "source": [ + "solve for best-fit line with zero intercept" + ] + }, + { + "cell_type": "code", + "execution_count": 92, + "id": "4bc5a07b-0601-4d07-89eb-b4645fdc1f0a", + "metadata": {}, + "outputs": [], + "source": [ + "df_results_onpolicy = pd.read_csv(f'./results/{exp_name}/vaso_eps_{eps_str}-onpolicy-orig.csv')\n", + "rmse_value_onpolicy = rmse(list(df_results_onpolicy['IS_value']) + list(df_results_onpolicy['WIS_value']), J_beh)" + ] + }, + { + "cell_type": "code", + "execution_count": 125, + "id": "102e1191-0628-4117-9bb8-7f8d2b20ea0f", + "metadata": {}, + "outputs": [], + "source": [ + "df_results_onpolicy_annotBeh = pd.read_csv(f'./results/{exp_name}a/vaso_eps_{eps_str}-onpolicy-aug_step-annotBeh.csv')\n", + "rmse_value_onpolicy_annotBeh = rmse(list(df_results_onpolicy_annotBeh['IS_value']) + list(df_results_onpolicy_annotBeh['WIS_value']), J_beh)" + ] + }, + { + "cell_type": "code", + "execution_count": 143, + "id": "dcf64c00-963a-48ba-b4a6-8132c78f6907", + "metadata": {}, + "outputs": [ + { + "data": { + "image/svg+xml": [ + "\n", + "\n", + "\n", + "\n", + " \n", + " \n", + " \n", + " \n", + " 2023-05-14T13:14:58.748960\n", + " image/svg+xml\n", + " \n", + " \n", + " Matplotlib v3.3.2, https://matplotlib.org/\n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " 
\n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " 
\n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " 
\n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + "\n" + ], + "text/plain": [ + "
" + ] + }, + "metadata": { + "needs_background": "light" + }, + "output_type": "display_data" + } + ], + "source": [ + "fig, ax = plt.subplots(figsize=(3,2.5))\n", + "x = kl_list\n", + "y1 = rmse_dict['orig']\n", + "y2 = rmse_dict['annotEval']\n", + "y3 = rmse_dict['annotBeh']\n", + "y4 = rmse_dict['annotBehConvertedAM']\n", + "\n", + "w1, cov1 = np.polyfit(x, y1, 1, cov=True)\n", + "e1 = np.sqrt(np.diag(cov1))\n", + "w2, cov2 = np.polyfit(x, y2, 1, cov=True)\n", + "e2 = np.sqrt(np.diag(cov2))\n", + "w3, cov3 = np.polyfit(x, y3, 1, cov=True)\n", + "e3 = np.sqrt(np.diag(cov3))\n", + "w4, cov4 = np.polyfit(x, y4, 1, cov=True)\n", + "e4 = np.sqrt(np.diag(cov4))\n", + "\n", + "w1_, _, _, _ = np.linalg.lstsq(np.array(x)[:,np.newaxis], np.array(y1)-rmse_value_onpolicy, rcond=None)\n", + "w2_, _, _, _ = np.linalg.lstsq(np.array(x)[:,np.newaxis], np.array(y2)-rmse_value_onpolicy_annotBeh, rcond=None)\n", + "w3_, _, _, _ = np.linalg.lstsq(np.array(x)[:,np.newaxis], np.array(y3)-rmse_value_onpolicy_annotBeh, rcond=None)\n", + "w4_, _, _, _ = np.linalg.lstsq(np.array(x)[:,np.newaxis], np.array(y4)-rmse_value_onpolicy_annotBeh, rcond=None)\n", + "w1__ = np.array([w1_[0], rmse_value_onpolicy])\n", + "w2__ = np.array([w2_[0], rmse_value_onpolicy_annotBeh])\n", + "w2__ = np.array([0, rmse_value_onpolicy_annotBeh])\n", + "w3__ = np.array([w3_[0], rmse_value_onpolicy_annotBeh])\n", + "w4__ = np.array([w4_[0], rmse_value_onpolicy_annotBeh])\n", + "\n", + "\n", + "x0 = np.arange(0,1.1,0.1)\n", + "plt.plot(x0, np.poly1d(w1__)(x0), ls='-', lw=2, c='gray', label='PDIS (baseline)')\n", + "plt.plot(x0, np.poly1d(w3__)(x0), ls='-', lw=2, c='violet', label='C*-PDIS ($G = Q^{\\pi_b}$)')\n", + "plt.plot(x0, np.poly1d(w4__)(x0), ls='-', lw=2, c='darkorchid', label='C*-PDIS ($G = Q^{\\pi_b} \\mapsto \\hat{Q}^{\\pi_e}$)')\n", + "plt.plot(x0, np.poly1d(w2__)(x0), ls='-', lw=2, c='slateblue', label='C*-PDIS ($G = Q^{\\pi_e}$)')\n", + "plt.fill_between(x0, np.poly1d(w1__-e1)(x0), 
np.poly1d(w1__+e1)(x0), fc='gray', alpha=0.2)\n", + "plt.fill_between(x0, np.poly1d(w2__-e2)(x0), np.poly1d(w2__+e2)(x0), fc='violet', alpha=0.2)\n", + "plt.fill_between(x0, np.poly1d(w3__-e3)(x0), np.poly1d(w3__+e3)(x0), fc='darkorchid', alpha=0.2)\n", + "plt.fill_between(x0, np.poly1d(w4__-e4)(x0), np.poly1d(w4__+e4)(x0), fc='slateblue', alpha=0.2)\n", + "plt.plot(x, y2, ls='none', marker='.', mew=0, c='slateblue', alpha=0.8)\n", + "plt.plot(x, y4, ls='none', marker='.', mew=0, c='darkorchid', alpha=0.8)\n", + "plt.plot(x, y3, ls='none', marker='.', mew=0, c='violet', alpha=0.8)\n", + "plt.plot(x, y1, ls='none', marker='.', mew=0, c='gray', alpha=0.8)\n", + "ylim = ax.get_ylim()\n", + "plt.plot(0.9, 0.318, marker='*', c='k', clip_on=False, zorder=100)\n", + "plt.ylim(ylim)\n", + "plt.legend(loc='lower center', bbox_to_anchor=(0.5,1.0), labelspacing=0.6)\n", + "plt.xticks([0,0.5,1])\n", + "plt.xlabel('$D_{\\mathrm{KL}}(\\pi_b || \\pi_e)$')\n", + "plt.ylabel('RMSE($v(\\pi_e),\\hat{v}(\\pi_e)$)')\n", + "plt.savefig('fig/annot_KL_0.pdf', bbox_inches='tight')\n", + "# plt.yscale('log')\n", + "# plt.savefig('fig/annot_KL_logscale_0.pdf', bbox_inches='tight')\n", + "plt.show()" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "2089e964-9574-4456-8185-d874fb86f57c", + "metadata": {}, + "outputs": [], + "source": [] + } + ], + "metadata": { + "kernelspec": { + "display_name": "Python 3 (ipykernel)", + "language": "python", + "name": "python3" + }, + "language_info": { + "codemirror_mode": { + "name": "ipython", + "version": 3 + }, + "file_extension": ".py", + "mimetype": "text/x-python", + "name": "python", + "nbconvert_exporter": "python", + "pygments_lexer": "ipython3", + "version": "3.8.5" + } + }, + "nbformat": 4, + "nbformat_minor": 5 +} diff --git a/sepsisSim/experiments/plots-missing.ipynb b/sepsisSim/experiments/plots-missing.ipynb new file mode 100644 index 0000000..22da673 --- /dev/null +++ b/sepsisSim/experiments/plots-missing.ipynb @@ 
-0,0 +1,6751 @@ +{ + "cells": [ + { + "cell_type": "code", + "execution_count": 1, + "id": "89b14dad-9e23-4f54-8bdf-dd5eebd8c025", + "metadata": {}, + "outputs": [], + "source": [ + "# ## Simulation parameters\n", + "exp_name = 'exp-FINAL'\n", + "eps = 0.10\n", + "eps_str = '0_1'\n", + "\n", + "run_idx_length = 1_000\n", + "N_val = 1_000\n", + "runs = 50" + ] + }, + { + "cell_type": "code", + "execution_count": 3, + "id": "a510145e-f9df-4460-a15d-377f2b4a6072", + "metadata": {}, + "outputs": [], + "source": [ + "import matplotlib\n", + "matplotlib.rcParams['font.sans-serif'] = ['Arial']\n", + "%config InlineBackend.figure_formats = ['svg']" + ] + }, + { + "cell_type": "code", + "execution_count": 4, + "id": "11339963-17b7-44f7-9823-59ac511a4a8a", + "metadata": {}, + "outputs": [], + "source": [ + "import numpy as np\n", + "import pandas as pd\n", + "from tqdm import tqdm\n", + "from collections import defaultdict\n", + "import pickle\n", + "import itertools\n", + "import copy\n", + "import random\n", + "import matplotlib.pyplot as plt\n", + "import seaborn as sns\n", + "import scipy.stats\n", + "from sklearn import metrics\n", + "import itertools\n", + "\n", + "import joblib\n", + "from joblib import Parallel, delayed" + ] + }, + { + "cell_type": "code", + "execution_count": 5, + "id": "0d61671d-b32e-4ec5-8859-85a5fae5c3c4", + "metadata": {}, + "outputs": [], + "source": [ + "from OPE_utils_new import (\n", + " format_data_tensor,\n", + " policy_eval_analytic_finite,\n", + " OPE_IS_h,\n", + " compute_behavior_policy_h,\n", + ")" + ] + }, + { + "cell_type": "code", + "execution_count": 6, + "id": "0c573be0-9afb-4dc0-b018-4c62a0ce15f7", + "metadata": {}, + "outputs": [], + "source": [ + "def policy_eval_helper(π):\n", + " V_H = policy_eval_analytic_finite(P.transpose((1,0,2)), R, π, gamma, H)\n", + " Q_H = [(R + gamma * P.transpose((1,0,2)) @ V_H[h]) for h in range(1,H)] + [R]\n", + " J = isd @ V_H[0]\n", + " # Check recursive relationships\n", + " assert len(Q_H) == 
H\n", + " assert len(V_H) == H\n", + " assert np.all(Q_H[-1] == R)\n", + " assert np.all(np.sum(π * Q_H[-1], axis=1) == V_H[-1])\n", + " assert np.all(R + gamma * P.transpose((1,0,2)) @ V_H[-1] == Q_H[-2])\n", + " return V_H, Q_H, J" + ] + }, + { + "cell_type": "code", + "execution_count": 7, + "id": "a0ab1c7c-bcd3-43ad-848b-93268fa1a2b6", + "metadata": {}, + "outputs": [], + "source": [ + "def iqm(x):\n", + " return scipy.stats.trim_mean(x, proportiontocut=0.25, axis=None)" + ] + }, + { + "cell_type": "code", + "execution_count": 8, + "id": "467e2e02-2f18-4fa9-9d17-3b9ae0e4b8dd", + "metadata": {}, + "outputs": [], + "source": [ + "NSTEPS = H = 20 # max episode length in historical data # Horizon of the MDP\n", + "G_min = -1 # the minimum possible return\n", + "G_max = 1 # the maximum possible return\n", + "nS, nA = 1442, 8\n", + "\n", + "PROB_DIAB = 0.2" + ] + }, + { + "cell_type": "code", + "execution_count": 9, + "id": "fdf1dfc0-7d33-4f7d-952c-5b77668f2968", + "metadata": {}, + "outputs": [], + "source": [ + "# Ground truth MDP model\n", + "MDP_parameters = joblib.load('../data/MDP_parameters.joblib')\n", + "P = MDP_parameters['transition_matrix_absorbing'] # (A, S, S_next)\n", + "R = MDP_parameters['reward_matrix_absorbing_SA'] # (S, A)\n", + "nS, nA = R.shape\n", + "gamma = 0.99\n", + "\n", + "# unif rand isd, mixture of diabetic state\n", + "isd = joblib.load('../data/modified_prior_initial_state_absorbing.joblib')\n", + "isd = (isd > 0).astype(float)\n", + "isd[:720] = isd[:720] / isd[:720].sum() * (1-PROB_DIAB)\n", + "isd[720:] = isd[720:] / isd[720:].sum() * (PROB_DIAB)" + ] + }, + { + "cell_type": "code", + "execution_count": 10, + "id": "c210ea9b-a6e0-42e6-be21-35a8e439c0c8", + "metadata": {}, + "outputs": [], + "source": [ + "# Precomputed optimal policy\n", + "π_star = joblib.load('../data/π_star.joblib')" + ] + }, + { + "attachments": {}, + "cell_type": "markdown", + "id": "9342ffbd-00f1-4916-b418-054a289f20bd", + "metadata": { + "tags": [] + }, + 
"source": [ + "## Policies" + ] + }, + { + "cell_type": "code", + "execution_count": 11, + "id": "645c4315-1109-4b41-b7ec-177e03e3343d", + "metadata": {}, + "outputs": [], + "source": [ + "# vaso unif, mv abx optimal\n", + "π_unif = (np.tile(π_star.reshape((-1,2,2,2)).sum(axis=3, keepdims=True), (1,1,1,2)).reshape((-1, 8)) / 2)" + ] + }, + { + "attachments": {}, + "cell_type": "markdown", + "id": "32ece6ab-6dca-479a-ab85-0cc3233e9f72", + "metadata": {}, + "source": [ + "### Behavior policy" + ] + }, + { + "cell_type": "code", + "execution_count": 12, + "id": "639c3b25-c715-4a90-8d1a-094e4478b9c9", + "metadata": {}, + "outputs": [], + "source": [ + "# vaso eps=0.5, mv abx optimal\n", + "π_beh = (np.tile(π_star.reshape((-1,2,2,2)).sum(axis=3, keepdims=True), (1,1,1,2)).reshape((-1, 8)) / 2)\n", + "π_beh[π_star == 1] = 1-eps\n", + "π_beh[π_beh == 0.5] = eps" + ] + }, + { + "cell_type": "code", + "execution_count": 13, + "id": "50a792f2-1f9d-40f7-89db-bd9245eabdf3", + "metadata": {}, + "outputs": [ + { + "data": { + "text/plain": [ + "0.25038354793851164" + ] + }, + "execution_count": 13, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "V_H_beh, Q_H_beh, J_beh = policy_eval_helper(π_beh)\n", + "J_beh" + ] + }, + { + "attachments": {}, + "cell_type": "markdown", + "id": "8d88637a-8298-4f1f-88ef-ab11b05ec08b", + "metadata": {}, + "source": [ + "### Optimal policy" + ] + }, + { + "cell_type": "code", + "execution_count": 14, + "id": "d849f76e-01a1-48f8-9d8a-e9d0a880cadf", + "metadata": {}, + "outputs": [ + { + "data": { + "text/plain": [ + "0.40877179296760224" + ] + }, + "execution_count": 14, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "V_H_star, Q_H_star, J_star = policy_eval_helper(π_star)\n", + "J_star" + ] + }, + { + "attachments": {}, + "cell_type": "markdown", + "id": "44136cb8-d2e0-459d-8ee1-f15b2ed8a62c", + "metadata": {}, + "source": [ + "### flip action for x% states" + ] + }, + { + "cell_type": 
"code", + "execution_count": 15, + "id": "97e67e52-09c6-4f0d-b435-1304a88a98d6", + "metadata": {}, + "outputs": [], + "source": [ + "def get_π_flip(pol_flip_seed, pol_flip_num):\n", + " rng_flip = np.random.default_rng(pol_flip_seed)\n", + " flip_states = rng_flip.choice(range(1440), pol_flip_num, replace=False)\n", + "\n", + " π_tmp = (np.tile(π_star.reshape((-1,2,2,2)).sum(axis=3, keepdims=True), (1,1,1,2)).reshape((-1, 8)) / 2)\n", + " π_flip = π_tmp.copy()\n", + " π_flip[π_tmp == 0.5] = 0\n", + " π_flip[π_star == 1] = 1\n", + " for s in flip_states:\n", + " π_flip[s, π_tmp[s] == 0.5] = 1\n", + " π_flip[s, π_star[s] == 1] = 0\n", + " assert π_flip.sum(axis=1).mean() == 1\n", + " return π_flip" + ] + }, + { + "cell_type": "code", + "execution_count": 16, + "id": "67c8441f-07bc-425a-ad5b-26b970653986", + "metadata": {}, + "outputs": [], + "source": [ + "πs_flip_list = []\n", + "for flip_num in [0, 50, 100, 200, 300, 400]:\n", + " for flip_seed in [0, 42, 123, 424242, 10000]:\n", + " if flip_num == 0 and flip_seed != 0: continue\n", + " π_flip = get_π_flip(flip_seed, flip_num)\n", + " πs_flip_list.append(π_flip)" + ] + }, + { + "cell_type": "code", + "execution_count": 17, + "id": "976a7213-6cb1-4511-a888-80173c967581", + "metadata": {}, + "outputs": [ + { + "name": "stderr", + "output_type": "stream", + "text": [ + "100%|███████████████████████████████████████████████████████████████████████████████████| 26/26 [00:18<00:00, 1.38it/s]\n" + ] + } + ], + "source": [ + "v_list = []\n", + "for π_eval in tqdm(πs_flip_list):\n", + " _, _, J_eval = policy_eval_helper(π_eval)\n", + " v_list.append(J_eval)" + ] + }, + { + "cell_type": "code", + "execution_count": 18, + "id": "c421cb40-2069-4d04-be56-2456dc6810b4", + "metadata": {}, + "outputs": [ + { + "data": { + "image/svg+xml": [ + "\n", + "\n", + "\n", + "\n", + " \n", + " \n", + " \n", + " \n", + " 2023-04-30T11:24:55.462034\n", + " image/svg+xml\n", + " \n", + " \n", + " Matplotlib v3.3.2, https://matplotlib.org/\n", 
+ " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + 
" \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + "\n" + ], + "text/plain": [ + "
" + ] + }, + "metadata": { + "needs_background": "light" + }, + "output_type": "display_data" + } + ], + "source": [ + "plt.plot(v_list, ls='none', marker='.')\n", + "plt.axhline(J_beh, c='k', ls='--', label='behavior', zorder=0)\n", + "plt.show()" + ] + }, + { + "attachments": {}, + "cell_type": "markdown", + "id": "62b4937e-484a-4740-90ea-8cf91fd19578", + "metadata": {}, + "source": [ + "## Load results" + ] + }, + { + "cell_type": "code", + "execution_count": 19, + "id": "73d56e0a-de4a-4a90-9c3a-c4b78a02e64b", + "metadata": {}, + "outputs": [], + "source": [ + "df_results_0 = pd.read_csv(f'./results/{exp_name}-1/vaso_eps_{eps_str}-observed.csv')" + ] + }, + { + "cell_type": "code", + "execution_count": 20, + "id": "d23aed3c-898f-4c44-b32b-20f744ab3a7d", + "metadata": {}, + "outputs": [], + "source": [ + "dfs_results_orig_list = []\n", + "for flip_num in [0, 50, 100, 200, 300, 400]:\n", + " for flip_seed in [0, 42, 123, 424242, 10000]:\n", + " if flip_num == 0 and flip_seed != 0: continue\n", + " pol_name = f'flip{flip_num}_seed{flip_seed}'\n", + " df_results_ = pd.read_csv(f'./results/{exp_name}-1/vaso_eps_{eps_str}-{pol_name}-orig.csv')\n", + " dfs_results_orig_list.append(df_results_)" + ] + }, + { + "cell_type": "code", + "execution_count": 21, + "id": "074e313e-ccc4-475f-857b-12b83da8b48d", + "metadata": {}, + "outputs": [], + "source": [ + "dfs_results_annotEval_list = []\n", + "for flip_num in [0, 50, 100, 200, 300, 400]:\n", + " for flip_seed in [0, 42, 123, 424242, 10000]:\n", + " if flip_num == 0 and flip_seed != 0: continue\n", + " pol_name = f'flip{flip_num}_seed{flip_seed}'\n", + " df_results_ = pd.read_csv(f'./results/{exp_name}-2/vaso_eps_{eps_str}-{pol_name}-aug_step-annotEval.csv')\n", + " dfs_results_annotEval_list.append(df_results_)" + ] + }, + { + "cell_type": "code", + "execution_count": 22, + "id": "0ec98102-a6c2-4521-a6ba-6267d7369cdf", + "metadata": {}, + "outputs": [], + "source": [ + "ratio_list = list(np.arange(0.0, 1.1, 
0.1).round(1))" + ] + }, + { + "cell_type": "code", + "execution_count": 26, + "id": "31b214a1-8930-422a-8b81-774425fa7278", + "metadata": {}, + "outputs": [], + "source": [ + "dfs_results_annotEvalMissing_lists = []\n", + "for ratio in ratio_list:\n", + " dfs_results_annotEvalMissing_list = []\n", + " for flip_num in [0, 50, 100, 200, 300, 400]:\n", + " for flip_seed in [0, 42, 123, 424242, 10000]:\n", + " if flip_num == 0 and flip_seed != 0: continue\n", + " pol_name = f'flip{flip_num}_seed{flip_seed}'\n", + " df_results_ = pd.read_csv(f'./results/{exp_name}-4/vaso_eps_{eps_str}-{pol_name}-aug_step-annotEval-Noise_0.2-Missing_{ratio}.csv')\n", + " dfs_results_annotEvalMissing_list.append(df_results_)\n", + " dfs_results_annotEvalMissing_lists.append(dfs_results_annotEvalMissing_list)" + ] + }, + { + "cell_type": "code", + "execution_count": 27, + "id": "c82dd1c7-1d37-4398-8fb2-0a5eb3cb78b0", + "metadata": {}, + "outputs": [], + "source": [ + "dfs_results_annotEvalMissingImpute_lists = []\n", + "for ratio in ratio_list:\n", + " dfs_results_annotEvalMissingImpute_list = []\n", + " for flip_num in [0, 50, 100, 200, 300, 400]:\n", + " for flip_seed in [0, 42, 123, 424242, 10000]:\n", + " if flip_num == 0 and flip_seed != 0: continue\n", + " pol_name = f'flip{flip_num}_seed{flip_seed}'\n", + " df_results_ = pd.read_csv(f'./results/{exp_name}-4/vaso_eps_{eps_str}-{pol_name}-aug_step-annotEval-Noise_0.2-MissingImpute_{ratio}.csv')\n", + " dfs_results_annotEvalMissingImpute_list.append(df_results_)\n", + " dfs_results_annotEvalMissingImpute_lists.append(dfs_results_annotEvalMissingImpute_list)" + ] + }, + { + "attachments": {}, + "cell_type": "markdown", + "id": "1f8bff6a-1a52-4b51-8d6e-89c8c6ffd5c6", + "metadata": {}, + "source": [ + "## Plots" + ] + }, + { + "cell_type": "code", + "execution_count": 28, + "id": "a0802fdd-7fe9-405b-9a13-2557a2065a11", + "metadata": {}, + "outputs": [], + "source": [ + "exp_idx = 13\n", + "π_eval = πs_flip_list[exp_idx]\n", + "J_eval = 
v_list[exp_idx]\n", + "df_results_orig = dfs_results_orig_list[exp_idx]\n", + "df_results_annotEval = dfs_results_annotEval_list[exp_idx]" + ] + }, + { + "cell_type": "code", + "execution_count": 29, + "id": "a891204f-974c-4be3-8eb6-4c9446b67b7f", + "metadata": {}, + "outputs": [], + "source": [ + "df_results_annotEval_Missing = [dfs[exp_idx] for dfs in dfs_results_annotEvalMissing_lists]\n", + "df_results_annotEval_MissingImpute = [dfs[exp_idx] for dfs in dfs_results_annotEvalMissingImpute_lists]" + ] + }, + { + "cell_type": "code", + "execution_count": 31, + "id": "652094bf-2cec-4f04-8c65-39d62b3a5831", + "metadata": {}, + "outputs": [ + { + "data": { + "image/svg+xml": [ + "\n", + "\n", + "\n", + "\n", + " \n", + " \n", + " \n", + " \n", + " 2023-04-30T11:27:40.780824\n", + " image/svg+xml\n", + " \n", + " \n", + " Matplotlib v3.3.2, https://matplotlib.org/\n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " 
\n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " 
\n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " 
\n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " 
\n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " 
\n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " 
\n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " 
\n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " 
\n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " 
\n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + "\n" + ], + "text/plain": [ + "
" + ] + }, + "metadata": { + "needs_background": "light" + }, + "output_type": "display_data" + } + ], + "source": [ + "fig, ax = plt.subplots(figsize=(4,4))\n", + "plt.axhline(J_beh, c='k', ls='--', label='behavior', zorder=0)\n", + "plt.scatter(N_val, J_beh, marker='o', s=40, c='k', alpha=1)\n", + "plt.axhline(J_eval, c='k', ls=':', label='true', zorder=0)\n", + "plt.scatter([N_val]*runs, df_results_0['IS_value'], \n", + " marker='o', s=10, c='k', alpha=0.25, ec='none')\n", + "\n", + "for df, name, color in [\n", + " [df_results_orig, 'original', 'tab:blue'],\n", + " [df_results_annotEval, 'proposed', 'tab:green'],\n", + " [df_results_annotEval_Missing[5], 'proposed noise', 'tab:cyan'],\n", + " [df_results_annotEval_MissingImpute[5], 'proposed noise', 'blue'],\n", + "]:\n", + " plt.plot(0,0, c=color, label=name)\n", + " plt.scatter(df['ESS1'], df['IS_value'], marker='o', s=10, c=color, alpha=0.25, ec='none')\n", + " plt.scatter(iqm(df['ESS1']), iqm(df['IS_value']), marker='o', s=40, c=color, alpha=0.8)\n", + " plt.scatter(df['ESS1'], df['WIS_value'], marker='X', s=10, c=color, alpha=0.25, ec='none')\n", + " plt.scatter(iqm(df['ESS1']), iqm(df['WIS_value']), marker='X', s=40, c=color, alpha=0.8)\n", + "\n", + "plt.scatter(-100,0, c='gray', marker='o', label='IS')\n", + "plt.scatter(-100,0, c='gray', marker='X', label='WIS')\n", + "plt.xlabel('ESS')\n", + "plt.ylabel('OPE value')\n", + "# plt.ylim(0.1, 0.7)\n", + "plt.xlim(0, N_val*1.25)\n", + "plt.legend(bbox_to_anchor=(1.04, 1))\n", + "plt.show()" + ] + }, + { + "attachments": {}, + "cell_type": "markdown", + "id": "da761092-92b5-439c-aa25-9efd0ddb674a", + "metadata": {}, + "source": [ + "## Metrics" + ] + }, + { + "cell_type": "code", + "execution_count": 32, + "id": "980a4d9d-038e-43e0-a19b-6d5cd8d4c1c1", + "metadata": {}, + "outputs": [], + "source": [ + "def rmse(y1, y2):\n", + " return np.sqrt(np.mean(np.square(y1-y2)))" + ] + }, + { + "cell_type": "code", + "execution_count": 33, + "id": 
"949d31e2-c9c1-4f22-8169-4e8befa02ed2", + "metadata": {}, + "outputs": [], + "source": [ + "def compute_confmat(J_eval, v_est_list):\n", + " confmat = np.zeros((2,2))\n", + " confmat[int(J_eval >= J_beh), 0] += (v_est_list < df_results_0['IS_value'].iloc[:len(v_est_list)]).sum()\n", + " confmat[int(J_eval >= J_beh), 1] += (v_est_list >= df_results_0['IS_value'].iloc[:len(v_est_list)]).sum()\n", + " return confmat" + ] + }, + { + "cell_type": "code", + "execution_count": 61, + "id": "c5ee241c-90c4-4791-84e0-428443b3b6e2", + "metadata": {}, + "outputs": [], + "source": [ + "def compute_confmat_batch(v_list, v_est_list, v_beh):\n", + " confmat = np.zeros((2,2))\n", + " for J_eval, v_est in zip(v_list, v_est_list):\n", + " confmat[int(J_eval >= J_beh), 0] += (v_est < v_beh).sum()\n", + " confmat[int(J_eval >= J_beh), 1] += (v_est >= v_beh).sum()\n", + " return confmat" + ] + }, + { + "cell_type": "code", + "execution_count": 34, + "id": "ffd8f027-3605-4876-b6e6-4611ca95f1d3", + "metadata": {}, + "outputs": [], + "source": [ + "orig_v_IS = []\n", + "orig_v_WIS = []\n", + "orig_v_ESS = []\n", + "for π_eval, df_results_ in zip(πs_flip_list, dfs_results_orig_list):\n", + " orig_v_IS.append(df_results_['IS_value'])\n", + " orig_v_WIS.append(df_results_['WIS_value'])\n", + " orig_v_ESS.append(df_results_['ESS1'])" + ] + }, + { + "cell_type": "code", + "execution_count": 35, + "id": "3b167971-fe17-49d5-9b0b-a137b40f2531", + "metadata": {}, + "outputs": [], + "source": [ + "annotEval_v_IS = []\n", + "annotEval_v_WIS = []\n", + "annotEval_v_ESS = []\n", + "for π_eval, df_results_ in zip(πs_flip_list, dfs_results_annotEval_list):\n", + " annotEval_v_IS.append(df_results_['IS_value'])\n", + " annotEval_v_WIS.append(df_results_['WIS_value'])\n", + " annotEval_v_ESS.append(df_results_['ESS1'])" + ] + }, + { + "cell_type": "code", + "execution_count": 36, + "id": "2f553302-d7d2-4882-a030-d05ff0b75068", + "metadata": {}, + "outputs": [], + "source": [ + "orig_rmse_value = 
np.mean([rmse(l,v) for v,l in zip(v_list+v_list, orig_v_IS+orig_v_WIS)])\n", + "orig_spearman_corr = np.mean([scipy.stats.spearmanr(v_list+v_list, v_π_est_list).correlation for v_π_est_list in np.array(orig_v_IS+orig_v_WIS).T])" + ] + }, + { + "cell_type": "code", + "execution_count": 37, + "id": "d507ed94-0cca-442f-802c-bf79707fd2b3", + "metadata": {}, + "outputs": [], + "source": [ + "oracle_rmse_value = np.mean([rmse(l,v) for v,l in zip(v_list+v_list, annotEval_v_IS+annotEval_v_WIS)])\n", + "oracle_spearman_corr = np.mean([scipy.stats.spearmanr(v_list+v_list, v_π_est_list).correlation for v_π_est_list in np.array(annotEval_v_IS+annotEval_v_WIS).T])" + ] + }, + { + "cell_type": "code", + "execution_count": 62, + "id": "3a44e898-b5d2-4f56-bda6-ceec6f90987a", + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "=== Ratio 0.0\n", + "ESS: 76.82151971964208\n", + "RMSE: 0.113±0.038\n", + "Spearman: 0.596±0.110\n", + "Accuracy: 76.5%±3.5% \t FPR: 33.7%±8.7% \t FNR: 15.9%±4.6%\n" + ] + }, + { + "data": { + "text/html": [ + "\n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + "
proposed noise Predicted Ranking $$\\hat{v}(\\pi_e) < \\hat{v}(\\pi_b)$$ $$\\hat{v}(\\pi_e) \\geq \\hat{v}(\\pi_b)$$
True Ranking
$$v(\\pi_e) < v(\\pi_b)$$729371
$$v(\\pi_e) \\geq v(\\pi_b)$$2391261
" + ], + "text/plain": [ + "" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "name": "stdout", + "output_type": "stream", + "text": [ + "=== Ratio 0.1\n", + "ESS: 81.08907981324657\n", + "RMSE: 0.067±0.028\n", + "Spearman: 0.823±0.067\n", + "Accuracy: 82.6%±4.6% \t FPR: 23.2%±9.4% \t FNR: 13.2%±6.9%\n" + ] + }, + { + "data": { + "text/html": [ + "\n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + "
proposed noise Predicted Ranking $$\\hat{v}(\\pi_e) < \\hat{v}(\\pi_b)$$ $$\\hat{v}(\\pi_e) \\geq \\hat{v}(\\pi_b)$$
True Ranking
$$v(\\pi_e) < v(\\pi_b)$$845255
$$v(\\pi_e) \\geq v(\\pi_b)$$1981302
" + ], + "text/plain": [ + "" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "name": "stdout", + "output_type": "stream", + "text": [ + "=== Ratio 0.2\n", + "ESS: 94.72037116614614\n", + "RMSE: 0.047±0.011\n", + "Spearman: 0.903±0.035\n", + "Accuracy: 86.0%±4.1% \t FPR: 16.1%±8.8% \t FNR: 12.5%±8.2%\n" + ] + }, + { + "data": { + "text/html": [ + "\n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + "
proposed noise Predicted Ranking $$\\hat{v}(\\pi_e) < \\hat{v}(\\pi_b)$$ $$\\hat{v}(\\pi_e) \\geq \\hat{v}(\\pi_b)$$
True Ranking
$$v(\\pi_e) < v(\\pi_b)$$923177
$$v(\\pi_e) \\geq v(\\pi_b)$$1881312
" + ], + "text/plain": [ + "" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "name": "stdout", + "output_type": "stream", + "text": [ + "=== Ratio 0.3\n", + "ESS: 115.70414289800716\n", + "RMSE: 0.041±0.012\n", + "Spearman: 0.928±0.037\n", + "Accuracy: 87.3%±4.3% \t FPR: 14.0%±9.8% \t FNR: 11.7%±9.1%\n" + ] + }, + { + "data": { + "text/html": [ + "\n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + "
proposed noise Predicted Ranking $$\\hat{v}(\\pi_e) < \\hat{v}(\\pi_b)$$ $$\\hat{v}(\\pi_e) \\geq \\hat{v}(\\pi_b)$$
True Ranking
$$v(\\pi_e) < v(\\pi_b)$$946154
$$v(\\pi_e) \\geq v(\\pi_b)$$1761324
" + ], + "text/plain": [ + "" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "name": "stdout", + "output_type": "stream", + "text": [ + "=== Ratio 0.4\n", + "ESS: 142.24933589580476\n", + "RMSE: 0.035±0.008\n", + "Spearman: 0.951±0.018\n", + "Accuracy: 89.6%±4.4% \t FPR: 10.0%±9.5% \t FNR: 10.7%±9.6%\n" + ] + }, + { + "data": { + "text/html": [ + "\n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + "
proposed noise Predicted Ranking $$\\hat{v}(\\pi_e) < \\hat{v}(\\pi_b)$$ $$\\hat{v}(\\pi_e) \\geq \\hat{v}(\\pi_b)$$
True Ranking
$$v(\\pi_e) < v(\\pi_b)$$990110
$$v(\\pi_e) \\geq v(\\pi_b)$$1611339
" + ], + "text/plain": [ + "" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "name": "stdout", + "output_type": "stream", + "text": [ + "=== Ratio 0.5\n", + "ESS: 179.07423874556713\n", + "RMSE: 0.033±0.012\n", + "Spearman: 0.958±0.031\n", + "Accuracy: 89.8%±4.0% \t FPR: 10.5%±9.9% \t FNR: 10.1%±8.7%\n" + ] + }, + { + "data": { + "text/html": [ + "\n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + "
proposed noise Predicted Ranking $$\\hat{v}(\\pi_e) < \\hat{v}(\\pi_b)$$ $$\\hat{v}(\\pi_e) \\geq \\hat{v}(\\pi_b)$$
True Ranking
$$v(\\pi_e) < v(\\pi_b)$$985115
$$v(\\pi_e) \\geq v(\\pi_b)$$1511349
" + ], + "text/plain": [ + "" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "name": "stdout", + "output_type": "stream", + "text": [ + "=== Ratio 0.6\n", + "ESS: 238.77740059302315\n", + "RMSE: 0.028±0.008\n", + "Spearman: 0.972±0.013\n", + "Accuracy: 91.9%±4.0% \t FPR: 8.3%±9.5% \t FNR: 8.0%±7.9%\n" + ] + }, + { + "data": { + "text/html": [ + "\n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + "
proposed noise Predicted Ranking $$\\hat{v}(\\pi_e) < \\hat{v}(\\pi_b)$$ $$\\hat{v}(\\pi_e) \\geq \\hat{v}(\\pi_b)$$
True Ranking
$$v(\\pi_e) < v(\\pi_b)$$100991
$$v(\\pi_e) \\geq v(\\pi_b)$$1201380
" + ], + "text/plain": [ + "" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "name": "stdout", + "output_type": "stream", + "text": [ + "=== Ratio 0.7\n", + "ESS: 316.1310533344366\n", + "RMSE: 0.024±0.007\n", + "Spearman: 0.980±0.009\n", + "Accuracy: 92.9%±3.9% \t FPR: 6.5%±9.3% \t FNR: 7.6%±6.7%\n" + ] + }, + { + "data": { + "text/html": [ + "\n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + "
proposed noise Predicted Ranking $$\\hat{v}(\\pi_e) < \\hat{v}(\\pi_b)$$ $$\\hat{v}(\\pi_e) \\geq \\hat{v}(\\pi_b)$$
True Ranking
$$v(\\pi_e) < v(\\pi_b)$$102971
$$v(\\pi_e) \\geq v(\\pi_b)$$1141386
" + ], + "text/plain": [ + "" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "name": "stdout", + "output_type": "stream", + "text": [ + "=== Ratio 0.8\n", + "ESS: 428.489106328174\n", + "RMSE: 0.021±0.007\n", + "Spearman: 0.985±0.008\n", + "Accuracy: 94.9%±4.5% \t FPR: 5.8%±9.2% \t FNR: 4.5%±6.3%\n" + ] + }, + { + "data": { + "text/html": [ + "\n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + "
proposed noise Predicted Ranking $$\\hat{v}(\\pi_e) < \\hat{v}(\\pi_b)$$ $$\\hat{v}(\\pi_e) \\geq \\hat{v}(\\pi_b)$$
True Ranking
$$v(\\pi_e) < v(\\pi_b)$$103664
$$v(\\pi_e) \\geq v(\\pi_b)$$681432
" + ], + "text/plain": [ + "" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "name": "stdout", + "output_type": "stream", + "text": [ + "=== Ratio 0.9\n", + "ESS: 635.9417124667175\n", + "RMSE: 0.016±0.006\n", + "Spearman: 0.991±0.005\n", + "Accuracy: 95.2%±3.8% \t FPR: 4.9%±7.9% \t FNR: 4.7%±6.0%\n" + ] + }, + { + "data": { + "text/html": [ + "\n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + "
proposed noise Predicted Ranking $$\\hat{v}(\\pi_e) < \\hat{v}(\\pi_b)$$ $$\\hat{v}(\\pi_e) \\geq \\hat{v}(\\pi_b)$$
True Ranking
$$v(\\pi_e) < v(\\pi_b)$$104654
$$v(\\pi_e) \\geq v(\\pi_b)$$711429
" + ], + "text/plain": [ + "" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "name": "stdout", + "output_type": "stream", + "text": [ + "=== Ratio 1.0\n", + "ESS: 993.9926736346977\n", + "RMSE: 0.014±0.005\n", + "Spearman: 0.994±0.002\n", + "Accuracy: 95.3%±4.1% \t FPR: 4.4%±7.9% \t FNR: 4.9%±6.8%\n" + ] + }, + { + "data": { + "text/html": [ + "\n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + "
proposed noise Predicted Ranking $$\\hat{v}(\\pi_e) < \\hat{v}(\\pi_b)$$ $$\\hat{v}(\\pi_e) \\geq \\hat{v}(\\pi_b)$$
True Ranking
$$v(\\pi_e) < v(\\pi_b)$$105248
$$v(\\pi_e) \\geq v(\\pi_b)$$731427
" + ], + "text/plain": [ + "" + ] + }, + "metadata": {}, + "output_type": "display_data" + } + ], + "source": [ + "all_results_Missing = []\n", + "for ratio, df_results_lists in zip(ratio_list, dfs_results_annotEvalMissing_lists):\n", + " print('===', 'Ratio', ratio)\n", + " v_IS = []\n", + " v_WIS = []\n", + " v_ESS = []\n", + " for π_eval, df_results_ in zip(πs_flip_list, df_results_lists):\n", + " v_IS.append(df_results_['IS_value'])\n", + " v_WIS.append(df_results_['WIS_value'])\n", + " v_ESS.append(df_results_['ESS1'])\n", + " print('ESS:', np.mean(v_ESS))\n", + " \n", + " v_est_list = v_IS + v_WIS\n", + " rmse_value = [rmse(l, np.array(v_list+v_list)) for l in np.array(v_est_list).T]\n", + " print('RMSE: {:.3f}±{:.3f}'.format(np.mean(rmse_value).round(3), np.std(rmse_value).round(3)))\n", + " \n", + " spearman_corr = [scipy.stats.spearmanr(l, np.array(v_list+v_list)).correlation for l in np.array(v_est_list).T]\n", + " print('Spearman: {:.3f}±{:.3f}'.format(np.mean(spearman_corr).round(3), np.std(spearman_corr).round(3)))\n", + " \n", + " confmats_ = [compute_confmat_batch(v_list+v_list, l, vb) for l,vb in zip(np.array(v_est_list).T, df_results_0['IS_value'])]\n", + " confmat_sum = sum(confmats_)\n", + " (accuracy, fpr, fnr) = (\n", + " [(cm[0,0]+cm[1,1])/np.sum(cm) for cm in confmats_],\n", + " [cm[0,1]/(cm[0,0]+cm[0,1]) for cm in confmats_], \n", + " [cm[1,0]/(cm[1,0]+cm[1,1]) for cm in confmats_],\n", + " )\n", + " print('Accuracy: {:.1%}±{:.1%} \\t FPR: {:.1%}±{:.1%} \\t FNR: {:.1%}±{:.1%}'.format(\n", + " np.mean(accuracy), np.std(accuracy), \n", + " np.mean(fpr), np.std(fpr), \n", + " np.mean(fnr), np.std(fnr), \n", + " ))\n", + " display(pd.DataFrame(\n", + " (confmat_sum).astype(int), \n", + " index=['$$v(\\pi_e) < v(\\pi_b)$$', '$$v(\\pi_e) \\geq v(\\pi_b)$$'],\n", + " columns=['$$\\hat{v}(\\pi_e) < \\hat{v}(\\pi_b)$$', '$$\\hat{v}(\\pi_e) \\geq \\hat{v}(\\pi_b)$$'],\n", + " ).rename_axis(index='True Ranking', columns=f'{name} Predicted 
Ranking')\\\n", + " .style.background_gradient(cmap='Blues', vmin=0, vmax=1500))\n", + " \n", + " all_results_Missing.append({\n", + " 'spearman': spearman_corr,\n", + " 'rmse': rmse_value,\n", + " 'accuracy': accuracy,\n", + " 'fpr': fpr,\n", + " 'fnr': fnr\n", + " })" + ] + }, + { + "cell_type": "code", + "execution_count": 63, + "id": "8a0d6302-6680-4ce2-b183-72ce80bf12f6", + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "=== Ratio 0.0\n", + "ESS: 76.82151971964208\n", + "RMSE: 0.113±0.038\n", + "Spearman: 0.596±0.110\n", + "Accuracy: 76.5%±3.5% \t FPR: 33.7%±8.7% \t FNR: 15.9%±4.6%\n" + ] + }, + { + "data": { + "text/html": [ + "\n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + "
proposed noise Predicted Ranking $$\\hat{v}(\\pi_e) < \\hat{v}(\\pi_b)$$ $$\\hat{v}(\\pi_e) \\geq \\hat{v}(\\pi_b)$$
True Ranking
$$v(\\pi_e) < v(\\pi_b)$$729371
$$v(\\pi_e) \\geq v(\\pi_b)$$2391261
" + ], + "text/plain": [ + "" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "name": "stdout", + "output_type": "stream", + "text": [ + "=== Ratio 0.1\n", + "ESS: 662.5128750541377\n", + "RMSE: 0.033±0.008\n", + "Spearman: 0.976±0.006\n", + "Accuracy: 87.3%±5.6% \t FPR: 1.3%±3.6% \t FNR: 21.1%±10.7%\n" + ] + }, + { + "data": { + "text/html": [ + "\n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + "
proposed noise Predicted Ranking $$\\hat{v}(\\pi_e) < \\hat{v}(\\pi_b)$$ $$\\hat{v}(\\pi_e) \\geq \\hat{v}(\\pi_b)$$
True Ranking
$$v(\\pi_e) < v(\\pi_b)$$108614
$$v(\\pi_e) \\geq v(\\pi_b)$$3171183
" + ], + "text/plain": [ + "" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "name": "stdout", + "output_type": "stream", + "text": [ + "=== Ratio 0.2\n", + "ESS: 771.0996409286839\n", + "RMSE: 0.030±0.009\n", + "Spearman: 0.984±0.005\n", + "Accuracy: 89.0%±5.7% \t FPR: 0.8%±2.4% \t FNR: 18.5%±10.6%\n" + ] + }, + { + "data": { + "text/html": [ + "\n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + "
proposed noise Predicted Ranking $$\\hat{v}(\\pi_e) < \\hat{v}(\\pi_b)$$ $$\\hat{v}(\\pi_e) \\geq \\hat{v}(\\pi_b)$$
True Ranking
$$v(\\pi_e) < v(\\pi_b)$$10919
$$v(\\pi_e) \\geq v(\\pi_b)$$2771223
" + ], + "text/plain": [ + "" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "name": "stdout", + "output_type": "stream", + "text": [ + "=== Ratio 0.3\n", + "ESS: 835.3671529861018\n", + "RMSE: 0.026±0.010\n", + "Spearman: 0.988±0.004\n", + "Accuracy: 90.3%±5.8% \t FPR: 0.8%±2.7% \t FNR: 16.1%±10.5%\n" + ] + }, + { + "data": { + "text/html": [ + "\n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + "
proposed noise Predicted Ranking $$\\hat{v}(\\pi_e) < \\hat{v}(\\pi_b)$$ $$\\hat{v}(\\pi_e) \\geq \\hat{v}(\\pi_b)$$
True Ranking
$$v(\\pi_e) < v(\\pi_b)$$10919
$$v(\\pi_e) \\geq v(\\pi_b)$$2421258
" + ], + "text/plain": [ + "" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "name": "stdout", + "output_type": "stream", + "text": [ + "=== Ratio 0.4\n", + "ESS: 875.5712773230858\n", + "RMSE: 0.024±0.009\n", + "Spearman: 0.990±0.004\n", + "Accuracy: 91.0%±6.4% \t FPR: 0.8%±2.4% \t FNR: 15.1%±11.6%\n" + ] + }, + { + "data": { + "text/html": [ + "\n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + "
proposed noise Predicted Ranking $$\\hat{v}(\\pi_e) < \\hat{v}(\\pi_b)$$ $$\\hat{v}(\\pi_e) \\geq \\hat{v}(\\pi_b)$$
True Ranking
$$v(\\pi_e) < v(\\pi_b)$$10919
$$v(\\pi_e) \\geq v(\\pi_b)$$2261274
" + ], + "text/plain": [ + "" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "name": "stdout", + "output_type": "stream", + "text": [ + "=== Ratio 0.5\n", + "ESS: 908.3580069939785\n", + "RMSE: 0.021±0.009\n", + "Spearman: 0.992±0.004\n", + "Accuracy: 92.5%±5.1% \t FPR: 1.5%±3.8% \t FNR: 11.8%±9.6%\n" + ] + }, + { + "data": { + "text/html": [ + "\n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + "
proposed noise Predicted Ranking $$\\hat{v}(\\pi_e) < \\hat{v}(\\pi_b)$$ $$\\hat{v}(\\pi_e) \\geq \\hat{v}(\\pi_b)$$
True Ranking
$$v(\\pi_e) < v(\\pi_b)$$108317
$$v(\\pi_e) \\geq v(\\pi_b)$$1771323
" + ], + "text/plain": [ + "" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "name": "stdout", + "output_type": "stream", + "text": [ + "=== Ratio 0.6\n", + "ESS: 932.1812568169127\n", + "RMSE: 0.019±0.008\n", + "Spearman: 0.993±0.003\n", + "Accuracy: 93.5%±5.0% \t FPR: 1.9%±4.0% \t FNR: 9.9%±9.6%\n" + ] + }, + { + "data": { + "text/html": [ + "\n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + "
proposed noise Predicted Ranking $$\\hat{v}(\\pi_e) < \\hat{v}(\\pi_b)$$ $$\\hat{v}(\\pi_e) \\geq \\hat{v}(\\pi_b)$$
True Ranking
$$v(\\pi_e) < v(\\pi_b)$$107921
$$v(\\pi_e) \\geq v(\\pi_b)$$1481352
" + ], + "text/plain": [ + "" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "name": "stdout", + "output_type": "stream", + "text": [ + "=== Ratio 0.7\n", + "ESS: 951.5218790770941\n", + "RMSE: 0.017±0.007\n", + "Spearman: 0.993±0.003\n", + "Accuracy: 94.4%±4.7% \t FPR: 2.1%±4.8% \t FNR: 8.1%±8.7%\n" + ] + }, + { + "data": { + "text/html": [ + "\n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + "
proposed noise Predicted Ranking $$\\hat{v}(\\pi_e) < \\hat{v}(\\pi_b)$$ $$\\hat{v}(\\pi_e) \\geq \\hat{v}(\\pi_b)$$
True Ranking
$$v(\\pi_e) < v(\\pi_b)$$107723
$$v(\\pi_e) \\geq v(\\pi_b)$$1221378
" + ], + "text/plain": [ + "" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "name": "stdout", + "output_type": "stream", + "text": [ + "=== Ratio 0.8\n", + "ESS: 968.6742665551586\n", + "RMSE: 0.016±0.006\n", + "Spearman: 0.994±0.003\n", + "Accuracy: 94.8%±4.0% \t FPR: 3.2%±5.8% \t FNR: 6.6%±7.6%\n" + ] + }, + { + "data": { + "text/html": [ + "\n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + "
proposed noise Predicted Ranking $$\\hat{v}(\\pi_e) < \\hat{v}(\\pi_b)$$ $$\\hat{v}(\\pi_e) \\geq \\hat{v}(\\pi_b)$$
True Ranking
$$v(\\pi_e) < v(\\pi_b)$$106535
$$v(\\pi_e) \\geq v(\\pi_b)$$991401
" + ], + "text/plain": [ + "" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "name": "stdout", + "output_type": "stream", + "text": [ + "=== Ratio 0.9\n", + "ESS: 982.2525313784963\n", + "RMSE: 0.014±0.007\n", + "Spearman: 0.994±0.004\n", + "Accuracy: 95.5%±3.5% \t FPR: 3.7%±6.5% \t FNR: 5.1%±6.2%\n" + ] + }, + { + "data": { + "text/html": [ + "\n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + "
proposed noise Predicted Ranking $$\\hat{v}(\\pi_e) < \\hat{v}(\\pi_b)$$ $$\\hat{v}(\\pi_e) \\geq \\hat{v}(\\pi_b)$$
True Ranking
$$v(\\pi_e) < v(\\pi_b)$$105941
$$v(\\pi_e) \\geq v(\\pi_b)$$761424
" + ], + "text/plain": [ + "" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "name": "stdout", + "output_type": "stream", + "text": [ + "=== Ratio 1.0\n", + "ESS: 993.9926736346977\n", + "RMSE: 0.014±0.005\n", + "Spearman: 0.994±0.002\n", + "Accuracy: 95.3%±4.1% \t FPR: 4.4%±7.9% \t FNR: 4.9%±6.8%\n" + ] + }, + { + "data": { + "text/html": [ + "\n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + "
proposed noise Predicted Ranking $$\\hat{v}(\\pi_e) < \\hat{v}(\\pi_b)$$ $$\\hat{v}(\\pi_e) \\geq \\hat{v}(\\pi_b)$$
True Ranking
$$v(\\pi_e) < v(\\pi_b)$$105248
$$v(\\pi_e) \\geq v(\\pi_b)$$731427
" + ], + "text/plain": [ + "" + ] + }, + "metadata": {}, + "output_type": "display_data" + } + ], + "source": [ + "all_results_MissingImpute = []\n", + "for ratio, df_results_lists in zip(ratio_list, dfs_results_annotEvalMissingImpute_lists):\n", + " print('===', 'Ratio', ratio)\n", + " v_IS = []\n", + " v_WIS = []\n", + " v_ESS = []\n", + " for π_eval, df_results_ in zip(πs_flip_list, df_results_lists):\n", + " v_IS.append(df_results_['IS_value'])\n", + " v_WIS.append(df_results_['WIS_value'])\n", + " v_ESS.append(df_results_['ESS1'])\n", + " print('ESS:', np.mean(v_ESS))\n", + " \n", + " v_est_list = v_IS + v_WIS\n", + " rmse_value = [rmse(l, np.array(v_list+v_list)) for l in np.array(v_est_list).T]\n", + " print('RMSE: {:.3f}±{:.3f}'.format(np.mean(rmse_value).round(3), np.std(rmse_value).round(3)))\n", + " \n", + " spearman_corr = [scipy.stats.spearmanr(l, np.array(v_list+v_list)).correlation for l in np.array(v_est_list).T]\n", + " print('Spearman: {:.3f}±{:.3f}'.format(np.mean(spearman_corr).round(3), np.std(spearman_corr).round(3)))\n", + " \n", + " confmats_ = [compute_confmat_batch(v_list+v_list, l, vb) for l,vb in zip(np.array(v_est_list).T, df_results_0['IS_value'])]\n", + " confmat_sum = sum(confmats_)\n", + " (accuracy, fpr, fnr) = (\n", + " [(cm[0,0]+cm[1,1])/np.sum(cm) for cm in confmats_],\n", + " [cm[0,1]/(cm[0,0]+cm[0,1]) for cm in confmats_], \n", + " [cm[1,0]/(cm[1,0]+cm[1,1]) for cm in confmats_],\n", + " )\n", + " print('Accuracy: {:.1%}±{:.1%} \\t FPR: {:.1%}±{:.1%} \\t FNR: {:.1%}±{:.1%}'.format(\n", + " np.mean(accuracy), np.std(accuracy), \n", + " np.mean(fpr), np.std(fpr), \n", + " np.mean(fnr), np.std(fnr), \n", + " ))\n", + " display(pd.DataFrame(\n", + " (confmat_sum).astype(int), \n", + " index=['$$v(\\pi_e) < v(\\pi_b)$$', '$$v(\\pi_e) \\geq v(\\pi_b)$$'],\n", + " columns=['$$\\hat{v}(\\pi_e) < \\hat{v}(\\pi_b)$$', '$$\\hat{v}(\\pi_e) \\geq \\hat{v}(\\pi_b)$$'],\n", + " ).rename_axis(index='True Ranking', columns=f'{name} Predicted 
Ranking')\\\n", + " .style.background_gradient(cmap='Blues', vmin=0, vmax=1500))\n", + " \n", + " all_results_MissingImpute.append({\n", + " 'spearman': spearman_corr,\n", + " 'rmse': rmse_value,\n", + " 'accuracy': accuracy,\n", + " 'fpr': fpr,\n", + " 'fnr': fnr\n", + " })" + ] + }, + { + "cell_type": "code", + "execution_count": 64, + "id": "44db5686-b452-49e2-b97e-77047616022b", + "metadata": {}, + "outputs": [], + "source": [ + "df_all_results_Missing = pd.DataFrame(all_results_Missing, index=ratio_list)\n", + "df_all_results_Missing.index.name = 'ratio'\n", + "df_all_results_MissingImpute = pd.DataFrame(all_results_MissingImpute, index=ratio_list)\n", + "df_all_results_MissingImpute.index.name = 'ratio'" + ] + }, + { + "cell_type": "code", + "execution_count": 65, + "id": "f41ca0f6-0c3b-4f07-bc94-f1795d70aca9", + "metadata": {}, + "outputs": [], + "source": [ + "df_all_results_Missing_median = df_all_results_Missing.applymap(np.median)\n", + "df_all_results_Missing_mean = df_all_results_Missing.applymap(np.mean)\n", + "df_all_results_Missing_std = df_all_results_Missing.applymap(np.std)\n", + "df_all_results_MissingImpute_median = df_all_results_MissingImpute.applymap(np.median)\n", + "df_all_results_MissingImpute_mean = df_all_results_MissingImpute.applymap(np.mean)\n", + "df_all_results_MissingImpute_std = df_all_results_MissingImpute.applymap(np.std)" + ] + }, + { + "cell_type": "code", + "execution_count": 79, + "id": "7130e67e-4437-41f1-85b1-ba5d03f1d55f", + "metadata": {}, + "outputs": [ + { + "data": { + "image/svg+xml": [ + "\n", + "\n", + "\n", + "\n", + " \n", + " \n", + " \n", + " \n", + " 2023-05-08T11:56:29.778119\n", + " image/svg+xml\n", + " \n", + " \n", + " Matplotlib v3.3.2, https://matplotlib.org/\n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", 
+ " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + 
" \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " 
\n", + " \n", + " \n", + " \n", + " \n", + "\n" + ], + "text/plain": [ + "
" + ] + }, + "metadata": { + "needs_background": "light" + }, + "output_type": "display_data" + } + ], + "source": [ + "fig, ax = plt.subplots(3,1, sharex=True, figsize=(2.5,3))\n", + "df_all_results_Missing_mean['rmse'].plot(ls=(0,(1,1)), c='r', ax=ax[0])\n", + "df_all_results_Missing_mean['spearman'].plot(ls=(0,(1,1)), c='r', ax=ax[1])\n", + "df_all_results_Missing_mean['accuracy'].plot(ls=(0,(1,1)), c='r', ax=ax[2])\n", + "df_all_results_MissingImpute_mean['rmse'].plot(ls='--', c='g', ax=ax[0])\n", + "df_all_results_MissingImpute_mean['spearman'].plot(ls='--', c='g', ax=ax[1])\n", + "df_all_results_MissingImpute_mean['accuracy'].plot(ls='--', c='g', ax=ax[2])\n", + "ax[0].fill_between(ratio_list, \n", + " df_all_results_Missing_mean['rmse']-df_all_results_Missing_std['rmse'], \n", + " df_all_results_Missing_mean['rmse']+df_all_results_Missing_std['rmse'], fc='r', alpha=0.2,\n", + " )\n", + "ax[1].fill_between(ratio_list, \n", + " df_all_results_Missing_mean['spearman']-df_all_results_Missing_std['spearman'], \n", + " df_all_results_Missing_mean['spearman']+df_all_results_Missing_std['spearman'], fc='r', alpha=0.2,\n", + " )\n", + "ax[2].fill_between(ratio_list, \n", + " df_all_results_Missing_mean['accuracy']-df_all_results_Missing_std['accuracy'], \n", + " df_all_results_Missing_mean['accuracy']+df_all_results_Missing_std['accuracy'], fc='r', alpha=0.2,\n", + " )\n", + "ax[0].fill_between(ratio_list, \n", + " df_all_results_MissingImpute_mean['rmse']-df_all_results_MissingImpute_std['rmse'], \n", + " df_all_results_MissingImpute_mean['rmse']+df_all_results_MissingImpute_std['rmse'], fc='g', alpha=0.2,\n", + " )\n", + "ax[1].fill_between(ratio_list, \n", + " df_all_results_MissingImpute_mean['spearman']-df_all_results_MissingImpute_std['spearman'], \n", + " df_all_results_MissingImpute_mean['spearman']+df_all_results_MissingImpute_std['spearman'], fc='g', alpha=0.2,\n", + " )\n", + "ax[2].fill_between(ratio_list, \n", + " 
df_all_results_MissingImpute_mean['accuracy']-df_all_results_MissingImpute_std['accuracy'], \n", + " df_all_results_MissingImpute_mean['accuracy']+df_all_results_MissingImpute_std['accuracy'], fc='g', alpha=0.2,\n", + " )\n", + "ax[0].plot(1, oracle_rmse_value, marker='*', c='k')\n", + "ax[1].plot(1, oracle_spearman_corr, marker='*', c='k')\n", + "ax[2].plot(1, 0.957, marker='*', c='k')\n", + "ax[0].axhline(orig_rmse_value, c='gray', ls=(0,(0.5,0.5)))\n", + "ax[1].axhline(orig_spearman_corr, c='gray', ls=(0,(0.5,0.5)))\n", + "ax[2].axhline(0.765, c='gray', ls=(0,(0.5,0.5)))\n", + "ax[0].set_ylabel('RMSE')\n", + "ax[1].set_ylabel('Spearman\\nCorr.')\n", + "ax[2].set_ylabel('Bin. Class.\\nAccuracy')\n", + "ax[0].set_ylim(0, 0.15)\n", + "ax[1].set_ylim(0.45, 1.05)\n", + "ax[2].set_ylim(0.7, 1.05)\n", + "plt.xlabel('fraction of annotated samples')\n", + "fig.align_labels()\n", + "# plt.figlegend(\n", + "# [matplotlib.lines.Line2D([], [], c='gray', ls=':'), \n", + "# matplotlib.lines.Line2D([], [], c='k', marker='*', linestyle='none'),\n", + "# matplotlib.lines.Line2D([], [], c='r', ls='--'), \n", + "# matplotlib.lines.Line2D([], [], c='g'), \n", + "# ],\n", + "# ['baseline', 'ideal setting', 'w/o imputation', 'w/ imputation', ],\n", + "# loc='center', bbox_to_anchor=(0.5,1.0),\n", + "# ncol=2, handlelength=1.333, handletextpad=0.6, columnspacing=1,\n", + "# )\n", + "plt.savefig('fig/annotation_missing.pdf', bbox_inches='tight')\n", + "plt.show()" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "02f5e299-2124-4f79-a844-e20f5953d585", + "metadata": {}, + "outputs": [], + "source": [] + } + ], + "metadata": { + "kernelspec": { + "display_name": "Python 3 (ipykernel)", + "language": "python", + "name": "python3" + }, + "language_info": { + "codemirror_mode": { + "name": "ipython", + "version": 3 + }, + "file_extension": ".py", + "mimetype": "text/x-python", + "name": "python", + "nbconvert_exporter": "python", + "pygments_lexer": "ipython3", + 
"version": "3.8.5" + } + }, + "nbformat": 4, + "nbformat_minor": 5 +} diff --git a/sepsisSim/experiments/plots-noisy-v2.ipynb b/sepsisSim/experiments/plots-noisy-v2.ipynb new file mode 100644 index 0000000..aee7a55 --- /dev/null +++ b/sepsisSim/experiments/plots-noisy-v2.ipynb @@ -0,0 +1,6743 @@ +{ + "cells": [ + { + "cell_type": "code", + "execution_count": 1, + "id": "89b14dad-9e23-4f54-8bdf-dd5eebd8c025", + "metadata": {}, + "outputs": [], + "source": [ + "# ## Simulation parameters\n", + "exp_name = 'exp-FINAL'\n", + "eps = 0.10\n", + "eps_str = '0_1'\n", + "\n", + "run_idx_length = 1_000\n", + "N_val = 1_000\n", + "runs = 50" + ] + }, + { + "cell_type": "code", + "execution_count": 3, + "id": "a510145e-f9df-4460-a15d-377f2b4a6072", + "metadata": {}, + "outputs": [], + "source": [ + "import matplotlib\n", + "matplotlib.rcParams['font.sans-serif'] = ['Arial']\n", + "%config InlineBackend.figure_formats = ['svg']" + ] + }, + { + "cell_type": "code", + "execution_count": 4, + "id": "11339963-17b7-44f7-9823-59ac511a4a8a", + "metadata": {}, + "outputs": [], + "source": [ + "import numpy as np\n", + "import pandas as pd\n", + "from tqdm import tqdm\n", + "from collections import defaultdict\n", + "import pickle\n", + "import itertools\n", + "import copy\n", + "import random\n", + "import matplotlib.pyplot as plt\n", + "import seaborn as sns\n", + "import scipy.stats\n", + "from sklearn import metrics\n", + "import itertools\n", + "\n", + "import joblib\n", + "from joblib import Parallel, delayed" + ] + }, + { + "cell_type": "code", + "execution_count": 5, + "id": "0d61671d-b32e-4ec5-8859-85a5fae5c3c4", + "metadata": {}, + "outputs": [], + "source": [ + "from OPE_utils_new import (\n", + " format_data_tensor,\n", + " policy_eval_analytic_finite,\n", + " OPE_IS_h,\n", + " compute_behavior_policy_h,\n", + ")" + ] + }, + { + "cell_type": "code", + "execution_count": 6, + "id": "0c573be0-9afb-4dc0-b018-4c62a0ce15f7", + "metadata": {}, + "outputs": [], + "source": [ + "def 
policy_eval_helper(π):\n", + " V_H = policy_eval_analytic_finite(P.transpose((1,0,2)), R, π, gamma, H)\n", + " Q_H = [(R + gamma * P.transpose((1,0,2)) @ V_H[h]) for h in range(1,H)] + [R]\n", + " J = isd @ V_H[0]\n", + " # Check recursive relationships\n", + " assert len(Q_H) == H\n", + " assert len(V_H) == H\n", + " assert np.all(Q_H[-1] == R)\n", + " assert np.all(np.sum(π * Q_H[-1], axis=1) == V_H[-1])\n", + " assert np.all(R + gamma * P.transpose((1,0,2)) @ V_H[-1] == Q_H[-2])\n", + " return V_H, Q_H, J" + ] + }, + { + "cell_type": "code", + "execution_count": 7, + "id": "a0ab1c7c-bcd3-43ad-848b-93268fa1a2b6", + "metadata": {}, + "outputs": [], + "source": [ + "def iqm(x):\n", + " return scipy.stats.trim_mean(x, proportiontocut=0.25, axis=None)" + ] + }, + { + "cell_type": "code", + "execution_count": 8, + "id": "467e2e02-2f18-4fa9-9d17-3b9ae0e4b8dd", + "metadata": {}, + "outputs": [], + "source": [ + "NSTEPS = H = 20 # max episode length in historical data # Horizon of the MDP\n", + "G_min = -1 # the minimum possible return\n", + "G_max = 1 # the maximum possible return\n", + "nS, nA = 1442, 8\n", + "\n", + "PROB_DIAB = 0.2" + ] + }, + { + "cell_type": "code", + "execution_count": 9, + "id": "fdf1dfc0-7d33-4f7d-952c-5b77668f2968", + "metadata": {}, + "outputs": [], + "source": [ + "# Ground truth MDP model\n", + "MDP_parameters = joblib.load('../data/MDP_parameters.joblib')\n", + "P = MDP_parameters['transition_matrix_absorbing'] # (A, S, S_next)\n", + "R = MDP_parameters['reward_matrix_absorbing_SA'] # (S, A)\n", + "nS, nA = R.shape\n", + "gamma = 0.99\n", + "\n", + "# unif rand isd, mixture of diabetic state\n", + "isd = joblib.load('../data/modified_prior_initial_state_absorbing.joblib')\n", + "isd = (isd > 0).astype(float)\n", + "isd[:720] = isd[:720] / isd[:720].sum() * (1-PROB_DIAB)\n", + "isd[720:] = isd[720:] / isd[720:].sum() * (PROB_DIAB)" + ] + }, + { + "cell_type": "code", + "execution_count": 10, + "id": "c210ea9b-a6e0-42e6-be21-35a8e439c0c8", + 
"metadata": {}, + "outputs": [], + "source": [ + "# Precomputed optimal policy\n", + "π_star = joblib.load('../data/π_star.joblib')" + ] + }, + { + "attachments": {}, + "cell_type": "markdown", + "id": "9342ffbd-00f1-4916-b418-054a289f20bd", + "metadata": { + "tags": [] + }, + "source": [ + "## Policies" + ] + }, + { + "cell_type": "code", + "execution_count": 11, + "id": "645c4315-1109-4b41-b7ec-177e03e3343d", + "metadata": {}, + "outputs": [], + "source": [ + "# vaso unif, mv abx optimal\n", + "π_unif = (np.tile(π_star.reshape((-1,2,2,2)).sum(axis=3, keepdims=True), (1,1,1,2)).reshape((-1, 8)) / 2)" + ] + }, + { + "attachments": {}, + "cell_type": "markdown", + "id": "32ece6ab-6dca-479a-ab85-0cc3233e9f72", + "metadata": {}, + "source": [ + "### Behavior policy" + ] + }, + { + "cell_type": "code", + "execution_count": 12, + "id": "639c3b25-c715-4a90-8d1a-094e4478b9c9", + "metadata": {}, + "outputs": [], + "source": [ + "# vaso eps=0.5, mv abx optimal\n", + "π_beh = (np.tile(π_star.reshape((-1,2,2,2)).sum(axis=3, keepdims=True), (1,1,1,2)).reshape((-1, 8)) / 2)\n", + "π_beh[π_star == 1] = 1-eps\n", + "π_beh[π_beh == 0.5] = eps" + ] + }, + { + "cell_type": "code", + "execution_count": 13, + "id": "50a792f2-1f9d-40f7-89db-bd9245eabdf3", + "metadata": {}, + "outputs": [ + { + "data": { + "text/plain": [ + "0.25038354793851164" + ] + }, + "execution_count": 13, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "V_H_beh, Q_H_beh, J_beh = policy_eval_helper(π_beh)\n", + "J_beh" + ] + }, + { + "attachments": {}, + "cell_type": "markdown", + "id": "8d88637a-8298-4f1f-88ef-ab11b05ec08b", + "metadata": {}, + "source": [ + "### Optimal policy" + ] + }, + { + "cell_type": "code", + "execution_count": 14, + "id": "d849f76e-01a1-48f8-9d8a-e9d0a880cadf", + "metadata": {}, + "outputs": [ + { + "data": { + "text/plain": [ + "0.40877179296760224" + ] + }, + "execution_count": 14, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + 
"V_H_star, Q_H_star, J_star = policy_eval_helper(π_star)\n", + "J_star" + ] + }, + { + "attachments": {}, + "cell_type": "markdown", + "id": "44136cb8-d2e0-459d-8ee1-f15b2ed8a62c", + "metadata": {}, + "source": [ + "### flip action for x% states" + ] + }, + { + "cell_type": "code", + "execution_count": 15, + "id": "97e67e52-09c6-4f0d-b435-1304a88a98d6", + "metadata": {}, + "outputs": [], + "source": [ + "def get_π_flip(pol_flip_seed, pol_flip_num):\n", + " rng_flip = np.random.default_rng(pol_flip_seed)\n", + " flip_states = rng_flip.choice(range(1440), pol_flip_num, replace=False)\n", + "\n", + " π_tmp = (np.tile(π_star.reshape((-1,2,2,2)).sum(axis=3, keepdims=True), (1,1,1,2)).reshape((-1, 8)) / 2)\n", + " π_flip = π_tmp.copy()\n", + " π_flip[π_tmp == 0.5] = 0\n", + " π_flip[π_star == 1] = 1\n", + " for s in flip_states:\n", + " π_flip[s, π_tmp[s] == 0.5] = 1\n", + " π_flip[s, π_star[s] == 1] = 0\n", + " assert π_flip.sum(axis=1).mean() == 1\n", + " return π_flip" + ] + }, + { + "cell_type": "code", + "execution_count": 16, + "id": "67c8441f-07bc-425a-ad5b-26b970653986", + "metadata": {}, + "outputs": [], + "source": [ + "πs_flip_list = []\n", + "for flip_num in [0, 50, 100, 200, 300, 400]:\n", + " for flip_seed in [0, 42, 123, 424242, 10000]:\n", + " if flip_num == 0 and flip_seed != 0: continue\n", + " π_flip = get_π_flip(flip_seed, flip_num)\n", + " πs_flip_list.append(π_flip)" + ] + }, + { + "cell_type": "code", + "execution_count": 17, + "id": "976a7213-6cb1-4511-a888-80173c967581", + "metadata": {}, + "outputs": [ + { + "name": "stderr", + "output_type": "stream", + "text": [ + "100%|███████████████████████████████████████████████████████████████████████████████████| 26/26 [00:21<00:00, 1.21it/s]\n" + ] + } + ], + "source": [ + "v_list = []\n", + "for π_eval in tqdm(πs_flip_list):\n", + " _, _, J_eval = policy_eval_helper(π_eval)\n", + " v_list.append(J_eval)" + ] + }, + { + "cell_type": "code", + "execution_count": 18, + "id": 
"c421cb40-2069-4d04-be56-2456dc6810b4", + "metadata": {}, + "outputs": [ + { + "data": { + "image/svg+xml": [ + "\n", + "\n", + "\n", + "\n", + " \n", + " \n", + " \n", + " \n", + " 2023-05-05T13:24:57.339033\n", + " image/svg+xml\n", + " \n", + " \n", + " Matplotlib v3.3.2, https://matplotlib.org/\n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " 
\n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + "\n" + ], + "text/plain": [ + "
" + ] + }, + "metadata": { + "needs_background": "light" + }, + "output_type": "display_data" + } + ], + "source": [ + "plt.plot(v_list, ls='none', marker='.')\n", + "plt.axhline(J_beh, c='k', ls='--', label='behavior', zorder=0)\n", + "plt.show()" + ] + }, + { + "attachments": {}, + "cell_type": "markdown", + "id": "62b4937e-484a-4740-90ea-8cf91fd19578", + "metadata": {}, + "source": [ + "## Load results" + ] + }, + { + "cell_type": "code", + "execution_count": 19, + "id": "73d56e0a-de4a-4a90-9c3a-c4b78a02e64b", + "metadata": {}, + "outputs": [], + "source": [ + "df_results_0 = pd.read_csv(f'./results/{exp_name}-1/vaso_eps_{eps_str}-observed.csv')" + ] + }, + { + "cell_type": "code", + "execution_count": 20, + "id": "d23aed3c-898f-4c44-b32b-20f744ab3a7d", + "metadata": {}, + "outputs": [], + "source": [ + "dfs_results_orig_list = []\n", + "for flip_num in [0, 50, 100, 200, 300, 400]:\n", + " for flip_seed in [0, 42, 123, 424242, 10000]:\n", + " if flip_num == 0 and flip_seed != 0: continue\n", + " pol_name = f'flip{flip_num}_seed{flip_seed}'\n", + " df_results_ = pd.read_csv(f'./results/{exp_name}-1/vaso_eps_{eps_str}-{pol_name}-orig.csv')\n", + " dfs_results_orig_list.append(df_results_)" + ] + }, + { + "cell_type": "code", + "execution_count": 21, + "id": "074e313e-ccc4-475f-857b-12b83da8b48d", + "metadata": {}, + "outputs": [], + "source": [ + "dfs_results_annotEval_list = []\n", + "for flip_num in [0, 50, 100, 200, 300, 400]:\n", + " for flip_seed in [0, 42, 123, 424242, 10000]:\n", + " if flip_num == 0 and flip_seed != 0: continue\n", + " pol_name = f'flip{flip_num}_seed{flip_seed}'\n", + " df_results_ = pd.read_csv(f'./results/{exp_name}-2/vaso_eps_{eps_str}-{pol_name}-aug_step-annotEval.csv')\n", + " dfs_results_annotEval_list.append(df_results_)" + ] + }, + { + "cell_type": "code", + "execution_count": 22, + "id": "0ec98102-a6c2-4521-a6ba-6267d7369cdf", + "metadata": {}, + "outputs": [], + "source": [ + "noise_list = list(np.arange(0.0, 1.1, 
0.1).round(1))" + ] + }, + { + "cell_type": "code", + "execution_count": 29, + "id": "31b214a1-8930-422a-8b81-774425fa7278", + "metadata": {}, + "outputs": [], + "source": [ + "dfs_results_annotEvalNoiseMissing_lists = []\n", + "for noise in noise_list:\n", + " dfs_results_annotEvalNoise_list = []\n", + " for flip_num in [0, 50, 100, 200, 300, 400]:\n", + " for flip_seed in [0, 42, 123, 424242, 10000]:\n", + " if flip_num == 0 and flip_seed != 0: continue\n", + " pol_name = f'flip{flip_num}_seed{flip_seed}'\n", + " df_results_ = pd.read_csv(f'./results/{exp_name}-4/vaso_eps_{eps_str}-{pol_name}-aug_step-annotEval-Noise_{noise}-Missing_0.1.csv')\n", + " dfs_results_annotEvalNoise_list.append(df_results_)\n", + " dfs_results_annotEvalNoiseMissing_lists.append(dfs_results_annotEvalNoise_list)" + ] + }, + { + "cell_type": "code", + "execution_count": 28, + "id": "9a81ac30-a84b-40a2-9cf1-dc22ae08f6bb", + "metadata": {}, + "outputs": [], + "source": [ + "dfs_results_annotEvalNoiseMissingImpute_lists = []\n", + "for noise in noise_list:\n", + " dfs_results_annotEvalNoise_list = []\n", + " for flip_num in [0, 50, 100, 200, 300, 400]:\n", + " for flip_seed in [0, 42, 123, 424242, 10000]:\n", + " if flip_num == 0 and flip_seed != 0: continue\n", + " pol_name = f'flip{flip_num}_seed{flip_seed}'\n", + " df_results_ = pd.read_csv(f'./results/{exp_name}-4/vaso_eps_{eps_str}-{pol_name}-aug_step-annotEval-Noise_{noise}-MissingImpute_0.1.csv')\n", + " dfs_results_annotEvalNoise_list.append(df_results_)\n", + " dfs_results_annotEvalNoiseMissingImpute_lists.append(dfs_results_annotEvalNoise_list)" + ] + }, + { + "attachments": {}, + "cell_type": "markdown", + "id": "1f8bff6a-1a52-4b51-8d6e-89c8c6ffd5c6", + "metadata": {}, + "source": [ + "## Plots" + ] + }, + { + "cell_type": "code", + "execution_count": 30, + "id": "a0802fdd-7fe9-405b-9a13-2557a2065a11", + "metadata": {}, + "outputs": [], + "source": [ + "exp_idx = 13\n", + "π_eval = πs_flip_list[exp_idx]\n", + "J_eval = 
v_list[exp_idx]\n", + "df_results_orig = dfs_results_orig_list[exp_idx]\n", + "df_results_annotEval = dfs_results_annotEval_list[exp_idx]" + ] + }, + { + "cell_type": "code", + "execution_count": 31, + "id": "a891204f-974c-4be3-8eb6-4c9446b67b7f", + "metadata": {}, + "outputs": [], + "source": [ + "df_results_annotEval_NoiseMissing = [dfs[exp_idx] for dfs in dfs_results_annotEvalNoiseMissing_lists]\n", + "df_results_annotEval_NoiseMissingImpute = [dfs[exp_idx] for dfs in dfs_results_annotEvalNoiseMissingImpute_lists]" + ] + }, + { + "cell_type": "code", + "execution_count": 32, + "id": "652094bf-2cec-4f04-8c65-39d62b3a5831", + "metadata": {}, + "outputs": [ + { + "data": { + "image/svg+xml": [ + "\n", + "\n", + "\n", + "\n", + " \n", + " \n", + " \n", + " \n", + " 2023-05-05T13:27:32.479959\n", + " image/svg+xml\n", + " \n", + " \n", + " Matplotlib v3.3.2, https://matplotlib.org/\n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " 
\n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " 
\n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " 
\n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " 
\n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " 
\n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " 
\n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " 
\n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " 
\n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " 
\n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + "\n" + ], + "text/plain": [ + "
" + ] + }, + "metadata": { + "needs_background": "light" + }, + "output_type": "display_data" + } + ], + "source": [ + "fig, ax = plt.subplots(figsize=(4,4))\n", + "plt.axhline(J_beh, c='k', ls='--', label='behavior', zorder=0)\n", + "plt.scatter(N_val, J_beh, marker='o', s=40, c='k', alpha=1)\n", + "plt.axhline(J_eval, c='k', ls=':', label='true', zorder=0)\n", + "plt.scatter([N_val]*runs, df_results_0['IS_value'], \n", + " marker='o', s=10, c='k', alpha=0.25, ec='none')\n", + "\n", + "for df, name, color in [\n", + " [df_results_orig, 'original', 'tab:blue'],\n", + " [df_results_annotEval, 'proposed', 'tab:green'],\n", + " [df_results_annotEval_NoiseMissing[-1], 'proposed noise', 'tab:cyan'],\n", + " [df_results_annotEval_NoiseMissingImpute[-1], 'proposed noise', 'blue'],\n", + "]:\n", + " plt.plot(0,0, c=color, label=name)\n", + " plt.scatter(df['ESS1'], df['IS_value'], marker='o', s=10, c=color, alpha=0.25, ec='none')\n", + " plt.scatter(iqm(df['ESS1']), iqm(df['IS_value']), marker='o', s=40, c=color, alpha=0.8)\n", + " plt.scatter(df['ESS1'], df['WIS_value'], marker='X', s=10, c=color, alpha=0.25, ec='none')\n", + " plt.scatter(iqm(df['ESS1']), iqm(df['WIS_value']), marker='X', s=40, c=color, alpha=0.8)\n", + "\n", + "plt.scatter(-100,0, c='gray', marker='o', label='IS')\n", + "plt.scatter(-100,0, c='gray', marker='X', label='WIS')\n", + "plt.xlabel('ESS')\n", + "plt.ylabel('OPE value')\n", + "# plt.ylim(0.1, 0.7)\n", + "plt.xlim(0, N_val*1.25)\n", + "plt.legend(bbox_to_anchor=(1.04, 1))\n", + "plt.show()" + ] + }, + { + "attachments": {}, + "cell_type": "markdown", + "id": "da761092-92b5-439c-aa25-9efd0ddb674a", + "metadata": {}, + "source": [ + "## Metrics" + ] + }, + { + "cell_type": "code", + "execution_count": 33, + "id": "980a4d9d-038e-43e0-a19b-6d5cd8d4c1c1", + "metadata": {}, + "outputs": [], + "source": [ + "def rmse(y1, y2):\n", + " return np.sqrt(np.mean(np.square(y1-y2)))" + ] + }, + { + "cell_type": "code", + "execution_count": 34, + "id": 
"949d31e2-c9c1-4f22-8169-4e8befa02ed2", + "metadata": {}, + "outputs": [], + "source": [ + "def compute_confmat(J_eval, v_est_list):\n", + " confmat = np.zeros((2,2))\n", + " confmat[int(J_eval >= J_beh), 0] += (v_est_list < df_results_0['IS_value'].iloc[:len(v_est_list)]).sum()\n", + " confmat[int(J_eval >= J_beh), 1] += (v_est_list >= df_results_0['IS_value'].iloc[:len(v_est_list)]).sum()\n", + " return confmat" + ] + }, + { + "cell_type": "code", + "execution_count": 35, + "id": "fd117d72-4840-4510-a2d9-9013ef4a27ee", + "metadata": {}, + "outputs": [], + "source": [ + "def compute_confmat_batch(v_list, v_est_list, v_beh):\n", + " confmat = np.zeros((2,2))\n", + " for J_eval, v_est in zip(v_list, v_est_list):\n", + " confmat[int(J_eval >= J_beh), 0] += (v_est < v_beh).sum()\n", + " confmat[int(J_eval >= J_beh), 1] += (v_est >= v_beh).sum()\n", + " return confmat" + ] + }, + { + "cell_type": "code", + "execution_count": 36, + "id": "ffd8f027-3605-4876-b6e6-4611ca95f1d3", + "metadata": {}, + "outputs": [], + "source": [ + "orig_v_IS = []\n", + "orig_v_WIS = []\n", + "orig_v_ESS = []\n", + "for π_eval, df_results_ in zip(πs_flip_list, dfs_results_orig_list):\n", + " orig_v_IS.append(df_results_['IS_value'])\n", + " orig_v_WIS.append(df_results_['WIS_value'])\n", + " orig_v_ESS.append(df_results_['ESS1'])" + ] + }, + { + "cell_type": "code", + "execution_count": 37, + "id": "3b167971-fe17-49d5-9b0b-a137b40f2531", + "metadata": {}, + "outputs": [], + "source": [ + "annotEval_v_IS = []\n", + "annotEval_v_WIS = []\n", + "annotEval_v_ESS = []\n", + "for π_eval, df_results_ in zip(πs_flip_list, dfs_results_annotEval_list):\n", + " annotEval_v_IS.append(df_results_['IS_value'])\n", + " annotEval_v_WIS.append(df_results_['WIS_value'])\n", + " annotEval_v_ESS.append(df_results_['ESS1'])" + ] + }, + { + "cell_type": "code", + "execution_count": 38, + "id": "2f553302-d7d2-4882-a030-d05ff0b75068", + "metadata": {}, + "outputs": [], + "source": [ + "orig_rmse_value = 
np.mean([rmse(l,v) for v,l in zip(v_list+v_list, orig_v_IS+orig_v_WIS)])\n", + "orig_spearman_corr = np.mean([scipy.stats.spearmanr(v_list+v_list, v_π_est_list).correlation for v_π_est_list in np.array(orig_v_IS+orig_v_WIS).T])" + ] + }, + { + "cell_type": "code", + "execution_count": 39, + "id": "d507ed94-0cca-442f-802c-bf79707fd2b3", + "metadata": {}, + "outputs": [], + "source": [ + "oracle_rmse_value = np.mean([rmse(l,v) for v,l in zip(v_list+v_list, annotEval_v_IS+annotEval_v_WIS)])\n", + "oracle_spearman_corr = np.mean([scipy.stats.spearmanr(v_list+v_list, v_π_est_list).correlation for v_π_est_list in np.array(annotEval_v_IS+annotEval_v_WIS).T])" + ] + }, + { + "cell_type": "code", + "execution_count": 42, + "id": "4a27d8cd-3927-409f-a7e1-dd1d4b0ab5f8", + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "=== Noise 0.0\n", + "ESS: 81.08907981324657\n", + "RMSE: 0.063±0.018\n", + "Spearman: 0.837±0.064\n", + "Accuracy: 82.9%±4.2% \t FPR: 22.3%±9.4% \t FNR: 13.3%±6.8%\n" + ] + }, + { + "data": { + "text/html": [ + "\n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + "
proposed noise Predicted Ranking $$\\hat{v}(\\pi_e) < \\hat{v}(\\pi_b)$$ $$\\hat{v}(\\pi_e) \\geq \\hat{v}(\\pi_b)$$
True Ranking
$$v(\\pi_e) < v(\\pi_b)$$855245
$$v(\\pi_e) \\geq v(\\pi_b)$$1991301
" + ], + "text/plain": [ + "" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "name": "stdout", + "output_type": "stream", + "text": [ + "=== Noise 0.1\n", + "ESS: 81.08907981324657\n", + "RMSE: 0.064±0.023\n", + "Spearman: 0.833±0.064\n", + "Accuracy: 83.0%±4.3% \t FPR: 22.6%±9.3% \t FNR: 12.9%±6.8%\n" + ] + }, + { + "data": { + "text/html": [ + "\n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + "
proposed noise Predicted Ranking $$\\hat{v}(\\pi_e) < \\hat{v}(\\pi_b)$$ $$\\hat{v}(\\pi_e) \\geq \\hat{v}(\\pi_b)$$
True Ranking
$$v(\\pi_e) < v(\\pi_b)$$851249
$$v(\\pi_e) \\geq v(\\pi_b)$$1931307
" + ], + "text/plain": [ + "" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "name": "stdout", + "output_type": "stream", + "text": [ + "=== Noise 0.2\n", + "ESS: 81.08907981324657\n", + "RMSE: 0.067±0.028\n", + "Spearman: 0.823±0.067\n", + "Accuracy: 82.6%±4.6% \t FPR: 23.2%±9.4% \t FNR: 13.2%±6.9%\n" + ] + }, + { + "data": { + "text/html": [ + "\n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + "
proposed noise Predicted Ranking $$\\hat{v}(\\pi_e) < \\hat{v}(\\pi_b)$$ $$\\hat{v}(\\pi_e) \\geq \\hat{v}(\\pi_b)$$
True Ranking
$$v(\\pi_e) < v(\\pi_b)$$845255
$$v(\\pi_e) \\geq v(\\pi_b)$$1981302
" + ], + "text/plain": [ + "" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "name": "stdout", + "output_type": "stream", + "text": [ + "=== Noise 0.3\n", + "ESS: 81.08907981324657\n", + "RMSE: 0.070±0.033\n", + "Spearman: 0.811±0.074\n", + "Accuracy: 81.9%±5.0% \t FPR: 24.4%±10.8% \t FNR: 13.5%±6.8%\n" + ] + }, + { + "data": { + "text/html": [ + "\n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + "
proposed noise Predicted Ranking $$\\hat{v}(\\pi_e) < \\hat{v}(\\pi_b)$$ $$\\hat{v}(\\pi_e) \\geq \\hat{v}(\\pi_b)$$
True Ranking
$$v(\\pi_e) < v(\\pi_b)$$832268
$$v(\\pi_e) \\geq v(\\pi_b)$$2021298
" + ], + "text/plain": [ + "" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "name": "stdout", + "output_type": "stream", + "text": [ + "=== Noise 0.4\n", + "ESS: 81.08907981324657\n", + "RMSE: 0.075±0.039\n", + "Spearman: 0.795±0.087\n", + "Accuracy: 81.7%±5.0% \t FPR: 24.9%±10.4% \t FNR: 13.5%±6.8%\n" + ] + }, + { + "data": { + "text/html": [ + "\n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + "
proposed noise Predicted Ranking $$\\hat{v}(\\pi_e) < \\hat{v}(\\pi_b)$$ $$\\hat{v}(\\pi_e) \\geq \\hat{v}(\\pi_b)$$
True Ranking
$$v(\\pi_e) < v(\\pi_b)$$826274
$$v(\\pi_e) \\geq v(\\pi_b)$$2021298
" + ], + "text/plain": [ + "" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "name": "stdout", + "output_type": "stream", + "text": [ + "=== Noise 0.5\n", + "ESS: 81.08907981324657\n", + "RMSE: 0.080±0.045\n", + "Spearman: 0.777±0.101\n", + "Accuracy: 80.9%±5.2% \t FPR: 25.8%±10.7% \t FNR: 14.2%±7.3%\n" + ] + }, + { + "data": { + "text/html": [ + "\n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + "
proposed noise Predicted Ranking $$\\hat{v}(\\pi_e) < \\hat{v}(\\pi_b)$$ $$\\hat{v}(\\pi_e) \\geq \\hat{v}(\\pi_b)$$
True Ranking
$$v(\\pi_e) < v(\\pi_b)$$816284
$$v(\\pi_e) \\geq v(\\pi_b)$$2131287
" + ], + "text/plain": [ + "" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "name": "stdout", + "output_type": "stream", + "text": [ + "=== Noise 0.6\n", + "ESS: 81.08907981324657\n", + "RMSE: 0.085±0.052\n", + "Spearman: 0.761±0.114\n", + "Accuracy: 80.5%±5.0% \t FPR: 26.5%±10.8% \t FNR: 14.4%±7.4%\n" + ] + }, + { + "data": { + "text/html": [ + "\n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + "
proposed noise Predicted Ranking $$\\hat{v}(\\pi_e) < \\hat{v}(\\pi_b)$$ $$\\hat{v}(\\pi_e) \\geq \\hat{v}(\\pi_b)$$
True Ranking
$$v(\\pi_e) < v(\\pi_b)$$809291
$$v(\\pi_e) \\geq v(\\pi_b)$$2161284
" + ], + "text/plain": [ + "" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "name": "stdout", + "output_type": "stream", + "text": [ + "=== Noise 0.7\n", + "ESS: 81.08907981324657\n", + "RMSE: 0.091±0.058\n", + "Spearman: 0.746±0.124\n", + "Accuracy: 80.2%±5.2% \t FPR: 27.2%±11.8% \t FNR: 14.5%±7.5%\n" + ] + }, + { + "data": { + "text/html": [ + "\n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + "
proposed noise Predicted Ranking $$\\hat{v}(\\pi_e) < \\hat{v}(\\pi_b)$$ $$\\hat{v}(\\pi_e) \\geq \\hat{v}(\\pi_b)$$
True Ranking
$$v(\\pi_e) < v(\\pi_b)$$801299
$$v(\\pi_e) \\geq v(\\pi_b)$$2171283
" + ], + "text/plain": [ + "" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "name": "stdout", + "output_type": "stream", + "text": [ + "=== Noise 0.8\n", + "ESS: 81.08907981324657\n", + "RMSE: 0.098±0.065\n", + "Spearman: 0.730±0.132\n", + "Accuracy: 79.6%±5.4% \t FPR: 27.5%±12.0% \t FNR: 15.2%±7.8%\n" + ] + }, + { + "data": { + "text/html": [ + "\n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + "
proposed noise Predicted Ranking $$\\hat{v}(\\pi_e) < \\hat{v}(\\pi_b)$$ $$\\hat{v}(\\pi_e) \\geq \\hat{v}(\\pi_b)$$
True Ranking
$$v(\\pi_e) < v(\\pi_b)$$797303
$$v(\\pi_e) \\geq v(\\pi_b)$$2281272
" + ], + "text/plain": [ + "" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "name": "stdout", + "output_type": "stream", + "text": [ + "=== Noise 0.9\n", + "ESS: 81.08907981324657\n", + "RMSE: 0.104±0.072\n", + "Spearman: 0.713±0.141\n", + "Accuracy: 79.1%±5.4% \t FPR: 28.1%±12.5% \t FNR: 15.6%±8.1%\n" + ] + }, + { + "data": { + "text/html": [ + "\n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + "
proposed noise Predicted Ranking $$\\hat{v}(\\pi_e) < \\hat{v}(\\pi_b)$$ $$\\hat{v}(\\pi_e) \\geq \\hat{v}(\\pi_b)$$
True Ranking
$$v(\\pi_e) < v(\\pi_b)$$791309
$$v(\\pi_e) \\geq v(\\pi_b)$$2341266
" + ], + "text/plain": [ + "" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "name": "stdout", + "output_type": "stream", + "text": [ + "=== Noise 1.0\n", + "ESS: 81.08907981324657\n", + "RMSE: 0.111±0.078\n", + "Spearman: 0.695±0.149\n", + "Accuracy: 78.4%±5.7% \t FPR: 29.6%±13.7% \t FNR: 15.7%±8.4%\n" + ] + }, + { + "data": { + "text/html": [ + "\n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + "
proposed noise Predicted Ranking $$\\hat{v}(\\pi_e) < \\hat{v}(\\pi_b)$$ $$\\hat{v}(\\pi_e) \\geq \\hat{v}(\\pi_b)$$
True Ranking
$$v(\\pi_e) < v(\\pi_b)$$774326
$$v(\\pi_e) \\geq v(\\pi_b)$$2361264
" + ], + "text/plain": [ + "" + ] + }, + "metadata": {}, + "output_type": "display_data" + } + ], + "source": [ + "all_results_Missing = []\n", + "for noise, df_results_lists in zip(noise_list, dfs_results_annotEvalNoiseMissing_lists):\n", + " print('===', 'Noise', noise)\n", + " v_IS = []\n", + " v_WIS = []\n", + " v_ESS = []\n", + " for π_eval, df_results_ in zip(πs_flip_list, df_results_lists):\n", + " v_IS.append(df_results_['IS_value'])\n", + " v_WIS.append(df_results_['WIS_value'])\n", + " v_ESS.append(df_results_['ESS1'])\n", + " print('ESS:', np.mean(v_ESS))\n", + " \n", + " v_est_list = v_IS + v_WIS\n", + " rmse_value = [rmse(l, np.array(v_list+v_list)) for l in np.array(v_est_list).T]\n", + " print('RMSE: {:.3f}±{:.3f}'.format(np.mean(rmse_value).round(3), np.std(rmse_value).round(3)))\n", + " \n", + " spearman_corr = [scipy.stats.spearmanr(l, np.array(v_list+v_list)).correlation for l in np.array(v_est_list).T]\n", + " print('Spearman: {:.3f}±{:.3f}'.format(np.mean(spearman_corr).round(3), np.std(spearman_corr).round(3)))\n", + " \n", + " confmats_ = [compute_confmat_batch(v_list+v_list, l, vb) for l,vb in zip(np.array(v_est_list).T, df_results_0['IS_value'])]\n", + " confmat_sum = sum(confmats_)\n", + " (accuracy, fpr, fnr) = (\n", + " [(cm[0,0]+cm[1,1])/np.sum(cm) for cm in confmats_],\n", + " [cm[0,1]/(cm[0,0]+cm[0,1]) for cm in confmats_], \n", + " [cm[1,0]/(cm[1,0]+cm[1,1]) for cm in confmats_],\n", + " )\n", + " print('Accuracy: {:.1%}±{:.1%} \\t FPR: {:.1%}±{:.1%} \\t FNR: {:.1%}±{:.1%}'.format(\n", + " np.mean(accuracy), np.std(accuracy), \n", + " np.mean(fpr), np.std(fpr), \n", + " np.mean(fnr), np.std(fnr), \n", + " ))\n", + " display(pd.DataFrame(\n", + " (confmat_sum).astype(int), \n", + " index=['$$v(\\pi_e) < v(\\pi_b)$$', '$$v(\\pi_e) \\geq v(\\pi_b)$$'],\n", + " columns=['$$\\hat{v}(\\pi_e) < \\hat{v}(\\pi_b)$$', '$$\\hat{v}(\\pi_e) \\geq \\hat{v}(\\pi_b)$$'],\n", + " ).rename_axis(index='True Ranking', columns=f'{name} Predicted 
Ranking')\\\n", + " .style.background_gradient(cmap='Blues', vmin=0, vmax=1500))\n", + " \n", + " all_results_Missing.append({\n", + " 'spearman': spearman_corr,\n", + " 'rmse': rmse_value,\n", + " 'accuracy': accuracy,\n", + " 'fpr': fpr,\n", + " 'fnr': fnr\n", + " })" + ] + }, + { + "cell_type": "code", + "execution_count": 44, + "id": "7ee8de7f-eef2-4015-8199-e003dd9c1af4", + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "=== Noise 0.0\n", + "ESS: 662.5128750541377\n", + "RMSE: 0.032±0.008\n", + "Spearman: 0.979±0.005\n", + "Accuracy: 87.7%±5.7% \t FPR: 0.8%±2.5% \t FNR: 20.7%±10.4%\n" + ] + }, + { + "data": { + "text/html": [ + "\n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + "
proposed noise Predicted Ranking $$\\hat{v}(\\pi_e) < \\hat{v}(\\pi_b)$$ $$\\hat{v}(\\pi_e) \\geq \\hat{v}(\\pi_b)$$
True Ranking
$$v(\\pi_e) < v(\\pi_b)$$10919
$$v(\\pi_e) \\geq v(\\pi_b)$$3111189
" + ], + "text/plain": [ + "" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "name": "stdout", + "output_type": "stream", + "text": [ + "=== Noise 0.1\n", + "ESS: 662.5128750541377\n", + "RMSE: 0.032±0.008\n", + "Spearman: 0.978±0.005\n", + "Accuracy: 87.5%±5.4% \t FPR: 0.8%±2.5% \t FNR: 21.0%±10.0%\n" + ] + }, + { + "data": { + "text/html": [ + "\n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + "
proposed noise Predicted Ranking $$\\hat{v}(\\pi_e) < \\hat{v}(\\pi_b)$$ $$\\hat{v}(\\pi_e) \\geq \\hat{v}(\\pi_b)$$
True Ranking
$$v(\\pi_e) < v(\\pi_b)$$10919
$$v(\\pi_e) \\geq v(\\pi_b)$$3151185
" + ], + "text/plain": [ + "" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "name": "stdout", + "output_type": "stream", + "text": [ + "=== Noise 0.2\n", + "ESS: 662.5128750541377\n", + "RMSE: 0.033±0.008\n", + "Spearman: 0.976±0.006\n", + "Accuracy: 87.3%±5.6% \t FPR: 1.3%±3.6% \t FNR: 21.1%±10.7%\n" + ] + }, + { + "data": { + "text/html": [ + "\n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + "
proposed noise Predicted Ranking $$\\hat{v}(\\pi_e) < \\hat{v}(\\pi_b)$$ $$\\hat{v}(\\pi_e) \\geq \\hat{v}(\\pi_b)$$
True Ranking
$$v(\\pi_e) < v(\\pi_b)$$108614
$$v(\\pi_e) \\geq v(\\pi_b)$$3171183
" + ], + "text/plain": [ + "" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "name": "stdout", + "output_type": "stream", + "text": [ + "=== Noise 0.3\n", + "ESS: 662.5128750541377\n", + "RMSE: 0.034±0.009\n", + "Spearman: 0.973±0.007\n", + "Accuracy: 87.0%±5.7% \t FPR: 2.1%±5.7% \t FNR: 21.0%±11.3%\n" + ] + }, + { + "data": { + "text/html": [ + "\n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + "
proposed noise Predicted Ranking $$\\hat{v}(\\pi_e) < \\hat{v}(\\pi_b)$$ $$\\hat{v}(\\pi_e) \\geq \\hat{v}(\\pi_b)$$
True Ranking
$$v(\\pi_e) < v(\\pi_b)$$107723
$$v(\\pi_e) \\geq v(\\pi_b)$$3151185
" + ], + "text/plain": [ + "" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "name": "stdout", + "output_type": "stream", + "text": [ + "=== Noise 0.4\n", + "ESS: 662.5128750541377\n", + "RMSE: 0.036±0.010\n", + "Spearman: 0.968±0.008\n", + "Accuracy: 86.5%±6.0% \t FPR: 3.1%±6.8% \t FNR: 21.1%±12.1%\n" + ] + }, + { + "data": { + "text/html": [ + "\n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + "
proposed noise Predicted Ranking $$\\hat{v}(\\pi_e) < \\hat{v}(\\pi_b)$$ $$\\hat{v}(\\pi_e) \\geq \\hat{v}(\\pi_b)$$
True Ranking
$$v(\\pi_e) < v(\\pi_b)$$106634
$$v(\\pi_e) \\geq v(\\pi_b)$$3161184
" + ], + "text/plain": [ + "" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "name": "stdout", + "output_type": "stream", + "text": [ + "=== Noise 0.5\n", + "ESS: 662.5128750541377\n", + "RMSE: 0.039±0.011\n", + "Spearman: 0.962±0.011\n", + "Accuracy: 85.9%±6.3% \t FPR: 3.8%±7.7% \t FNR: 21.7%±13.2%\n" + ] + }, + { + "data": { + "text/html": [ + "\n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + "
proposed noise Predicted Ranking $$\\hat{v}(\\pi_e) < \\hat{v}(\\pi_b)$$ $$\\hat{v}(\\pi_e) \\geq \\hat{v}(\\pi_b)$$
True Ranking
$$v(\\pi_e) < v(\\pi_b)$$105842
$$v(\\pi_e) \\geq v(\\pi_b)$$3251175
" + ], + "text/plain": [ + "" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "name": "stdout", + "output_type": "stream", + "text": [ + "=== Noise 0.6\n", + "ESS: 662.5128750541377\n", + "RMSE: 0.042±0.012\n", + "Spearman: 0.956±0.013\n", + "Accuracy: 85.2%±6.5% \t FPR: 4.7%±8.6% \t FNR: 22.3%±13.8%\n" + ] + }, + { + "data": { + "text/html": [ + "\n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + "
proposed noise Predicted Ranking $$\\hat{v}(\\pi_e) < \\hat{v}(\\pi_b)$$ $$\\hat{v}(\\pi_e) \\geq \\hat{v}(\\pi_b)$$
True Ranking
$$v(\\pi_e) < v(\\pi_b)$$104852
$$v(\\pi_e) \\geq v(\\pi_b)$$3341166
" + ], + "text/plain": [ + "" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "name": "stdout", + "output_type": "stream", + "text": [ + "=== Noise 0.7\n", + "ESS: 662.5128750541377\n", + "RMSE: 0.045±0.014\n", + "Spearman: 0.949±0.016\n", + "Accuracy: 84.2%±6.9% \t FPR: 5.8%±8.7% \t FNR: 23.2%±14.8%\n" + ] + }, + { + "data": { + "text/html": [ + "\n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + "
proposed noise Predicted Ranking $$\\hat{v}(\\pi_e) < \\hat{v}(\\pi_b)$$ $$\\hat{v}(\\pi_e) \\geq \\hat{v}(\\pi_b)$$
True Ranking
$$v(\\pi_e) < v(\\pi_b)$$103664
$$v(\\pi_e) \\geq v(\\pi_b)$$3481152
" + ], + "text/plain": [ + "" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "name": "stdout", + "output_type": "stream", + "text": [ + "=== Noise 0.8\n", + "ESS: 662.5128750541377\n", + "RMSE: 0.049±0.015\n", + "Spearman: 0.940±0.020\n", + "Accuracy: 83.3%±7.1% \t FPR: 7.3%±10.4% \t FNR: 23.7%±15.7%\n" + ] + }, + { + "data": { + "text/html": [ + "\n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + "
proposed noise Predicted Ranking $$\\hat{v}(\\pi_e) < \\hat{v}(\\pi_b)$$ $$\\hat{v}(\\pi_e) \\geq \\hat{v}(\\pi_b)$$
True Ranking
$$v(\\pi_e) < v(\\pi_b)$$102080
$$v(\\pi_e) \\geq v(\\pi_b)$$3551145
" + ], + "text/plain": [ + "" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "name": "stdout", + "output_type": "stream", + "text": [ + "=== Noise 0.9\n", + "ESS: 662.5128750541377\n", + "RMSE: 0.053±0.016\n", + "Spearman: 0.930±0.023\n", + "Accuracy: 82.4%±7.7% \t FPR: 8.3%±11.5% \t FNR: 24.4%±16.8%\n" + ] + }, + { + "data": { + "text/html": [ + "\n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + "
proposed noise Predicted Ranking $$\\hat{v}(\\pi_e) < \\hat{v}(\\pi_b)$$ $$\\hat{v}(\\pi_e) \\geq \\hat{v}(\\pi_b)$$
True Ranking
$$v(\\pi_e) < v(\\pi_b)$$100991
$$v(\\pi_e) \\geq v(\\pi_b)$$3661134
" + ], + "text/plain": [ + "" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "name": "stdout", + "output_type": "stream", + "text": [ + "=== Noise 1.0\n", + "ESS: 662.5128750541377\n", + "RMSE: 0.057±0.018\n", + "Spearman: 0.919±0.029\n", + "Accuracy: 81.2%±7.7% \t FPR: 10.0%±13.0% \t FNR: 25.3%±17.8%\n" + ] + }, + { + "data": { + "text/html": [ + "\n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + "
proposed noise Predicted Ranking $$\\hat{v}(\\pi_e) < \\hat{v}(\\pi_b)$$ $$\\hat{v}(\\pi_e) \\geq \\hat{v}(\\pi_b)$$
True Ranking
$$v(\\pi_e) < v(\\pi_b)$$990110
$$v(\\pi_e) \\geq v(\\pi_b)$$3791121
" + ], + "text/plain": [ + "" + ] + }, + "metadata": {}, + "output_type": "display_data" + } + ], + "source": [ + "all_results_MissingImpute = []\n", + "for noise, df_results_lists in zip(noise_list, dfs_results_annotEvalNoiseMissingImpute_lists):\n", + " print('===', 'Noise', noise)\n", + " v_IS = []\n", + " v_WIS = []\n", + " v_ESS = []\n", + " for π_eval, df_results_ in zip(πs_flip_list, df_results_lists):\n", + " v_IS.append(df_results_['IS_value'])\n", + " v_WIS.append(df_results_['WIS_value'])\n", + " v_ESS.append(df_results_['ESS1'])\n", + " print('ESS:', np.mean(v_ESS))\n", + " \n", + " v_est_list = v_IS + v_WIS\n", + " rmse_value = [rmse(l, np.array(v_list+v_list)) for l in np.array(v_est_list).T]\n", + " print('RMSE: {:.3f}±{:.3f}'.format(np.mean(rmse_value).round(3), np.std(rmse_value).round(3)))\n", + " \n", + " spearman_corr = [scipy.stats.spearmanr(l, np.array(v_list+v_list)).correlation for l in np.array(v_est_list).T]\n", + " print('Spearman: {:.3f}±{:.3f}'.format(np.mean(spearman_corr).round(3), np.std(spearman_corr).round(3)))\n", + " \n", + " confmats_ = [compute_confmat_batch(v_list+v_list, l, vb) for l,vb in zip(np.array(v_est_list).T, df_results_0['IS_value'])]\n", + " confmat_sum = sum(confmats_)\n", + " (accuracy, fpr, fnr) = (\n", + " [(cm[0,0]+cm[1,1])/np.sum(cm) for cm in confmats_],\n", + " [cm[0,1]/(cm[0,0]+cm[0,1]) for cm in confmats_], \n", + " [cm[1,0]/(cm[1,0]+cm[1,1]) for cm in confmats_],\n", + " )\n", + " print('Accuracy: {:.1%}±{:.1%} \\t FPR: {:.1%}±{:.1%} \\t FNR: {:.1%}±{:.1%}'.format(\n", + " np.mean(accuracy), np.std(accuracy), \n", + " np.mean(fpr), np.std(fpr), \n", + " np.mean(fnr), np.std(fnr), \n", + " ))\n", + " display(pd.DataFrame(\n", + " (confmat_sum).astype(int), \n", + " index=['$$v(\\pi_e) < v(\\pi_b)$$', '$$v(\\pi_e) \\geq v(\\pi_b)$$'],\n", + " columns=['$$\\hat{v}(\\pi_e) < \\hat{v}(\\pi_b)$$', '$$\\hat{v}(\\pi_e) \\geq \\hat{v}(\\pi_b)$$'],\n", + " ).rename_axis(index='True Ranking', columns=f'{name} 
Predicted Ranking')\\\n", + " .style.background_gradient(cmap='Blues', vmin=0, vmax=1500))\n", + " \n", + " all_results_MissingImpute.append({\n", + " 'spearman': spearman_corr,\n", + " 'rmse': rmse_value,\n", + " 'accuracy': accuracy,\n", + " 'fpr': fpr,\n", + " 'fnr': fnr\n", + " })" + ] + }, + { + "cell_type": "code", + "execution_count": 45, + "id": "0fd6db75-a6d8-4ea6-8afe-ac6a4310e7cb", + "metadata": {}, + "outputs": [], + "source": [ + "df_all_results_Missing = pd.DataFrame(all_results_Missing, index=noise_list)\n", + "df_all_results_Missing.index.name = 'noise'\n", + "df_all_results_MissingImpute = pd.DataFrame(all_results_MissingImpute, index=noise_list)\n", + "df_all_results_MissingImpute.index.name = 'noise'" + ] + }, + { + "cell_type": "code", + "execution_count": 46, + "id": "f3a5335f-7d22-4bbb-bfd8-e76489ebc7b4", + "metadata": {}, + "outputs": [], + "source": [ + "df_all_results_Missing_median = df_all_results_Missing.applymap(np.median)\n", + "df_all_results_Missing_mean = df_all_results_Missing.applymap(np.mean)\n", + "df_all_results_Missing_std = df_all_results_Missing.applymap(np.std)\n", + "df_all_results_MissingImpute_median = df_all_results_MissingImpute.applymap(np.median)\n", + "df_all_results_MissingImpute_mean = df_all_results_MissingImpute.applymap(np.mean)\n", + "df_all_results_MissingImpute_std = df_all_results_MissingImpute.applymap(np.std)" + ] + }, + { + "cell_type": "code", + "execution_count": 51, + "id": "3aaee0f8-c7f1-4234-b3e5-5410e74f58d3", + "metadata": {}, + "outputs": [ + { + "data": { + "image/svg+xml": [ + "\n", + "\n", + "\n", + "\n", + " \n", + " \n", + " \n", + " \n", + " 2023-05-05T15:35:28.521637\n", + " image/svg+xml\n", + " \n", + " \n", + " Matplotlib v3.3.2, https://matplotlib.org/\n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", 
+ " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + 
" \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " 
\n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + "\n" + ], + "text/plain": [ + "
" + ] + }, + "metadata": { + "needs_background": "light" + }, + "output_type": "display_data" + } + ], + "source": [ + "fig, ax = plt.subplots(3,1, sharex=True, figsize=(2.5,3))\n", + "df_all_results_Missing_mean['rmse'].plot(ls='-', c='tab:blue', lw=1, ax=ax[0])\n", + "df_all_results_Missing_mean['spearman'].plot(ls='-', c='tab:blue', lw=1, ax=ax[1])\n", + "df_all_results_Missing_mean['accuracy'].plot(ls='-', c='tab:blue', lw=1, ax=ax[2])\n", + "ax[0].fill_between(noise_list, \n", + " df_all_results_Missing_mean['rmse']-df_all_results_Missing_std['rmse'], \n", + " df_all_results_Missing_mean['rmse']+df_all_results_Missing_std['rmse'], fc='tab:blue', alpha=0.2,\n", + " )\n", + "ax[1].fill_between(noise_list, \n", + " df_all_results_Missing_mean['spearman']-df_all_results_Missing_std['spearman'], \n", + " df_all_results_Missing_mean['spearman']+df_all_results_Missing_std['spearman'], fc='tab:blue', alpha=0.2,\n", + " )\n", + "ax[2].fill_between(noise_list, \n", + " df_all_results_Missing_mean['accuracy']-df_all_results_Missing_std['accuracy'], \n", + " df_all_results_Missing_mean['accuracy']+df_all_results_Missing_std['accuracy'], fc='tab:blue', alpha=0.2,\n", + " )\n", + "ax[0].plot(0, oracle_rmse_value, marker='*', c='k')\n", + "ax[1].plot(0, oracle_spearman_corr, marker='*', c='k')\n", + "ax[2].plot(0, 0.957, marker='*', c='k')\n", + "ax[0].axhline(orig_rmse_value, c='gray', ls=':')\n", + "ax[1].axhline(orig_spearman_corr, c='gray', ls=':')\n", + "ax[2].axhline(0.765, c='gray', ls=':')\n", + "ax[0].set_ylabel('RMSE')\n", + "ax[1].set_ylabel('Spearman\\nCorr.')\n", + "ax[2].set_ylabel('Bin. 
Class.\\nAccuracy')\n", + "ax[0].set_ylim(0, 0.15)\n", + "ax[1].set_ylim(0.45, 1.05)\n", + "ax[2].set_ylim(0.7, 1.05)\n", + "plt.xlabel('Std of annotation noise')\n", + "fig.align_labels()\n", + "plt.figlegend(\n", + " [matplotlib.lines.Line2D([], [], c='gray', ls=':'), \n", + " matplotlib.lines.Line2D([], [], c='k', marker='*', linestyle='none'),\n", + " matplotlib.lines.Line2D([], [], c='tab:blue'), \n", + " ],\n", + " ['baseline', 'ideal setting', 'proposed', ],\n", + " loc='center', bbox_to_anchor=(0.5,1.0),\n", + " ncol=2, handlelength=1.333, handletextpad=0.6, columnspacing=1,\n", + ")\n", + "plt.savefig('fig/annotation_noisy_missing10.pdf', bbox_inches='tight')\n", + "plt.show()" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "02f5e299-2124-4f79-a844-e20f5953d585", + "metadata": {}, + "outputs": [], + "source": [] + } + ], + "metadata": { + "kernelspec": { + "display_name": "Python 3 (ipykernel)", + "language": "python", + "name": "python3" + }, + "language_info": { + "codemirror_mode": { + "name": "ipython", + "version": 3 + }, + "file_extension": ".py", + "mimetype": "text/x-python", + "name": "python", + "nbconvert_exporter": "python", + "pygments_lexer": "ipython3", + "version": "3.8.5" + } + }, + "nbformat": 4, + "nbformat_minor": 5 +} diff --git a/sepsisSim/experiments/plots-noisy.ipynb b/sepsisSim/experiments/plots-noisy.ipynb new file mode 100644 index 0000000..7c671aa --- /dev/null +++ b/sepsisSim/experiments/plots-noisy.ipynb @@ -0,0 +1,5604 @@ +{ + "cells": [ + { + "cell_type": "code", + "execution_count": 1, + "id": "89b14dad-9e23-4f54-8bdf-dd5eebd8c025", + "metadata": {}, + "outputs": [], + "source": [ + "# ## Simulation parameters\n", + "exp_name = 'exp-FINAL'\n", + "eps = 0.10\n", + "eps_str = '0_1'\n", + "\n", + "run_idx_length = 1_000\n", + "N_val = 1_000\n", + "runs = 50" + ] + }, + { + "cell_type": "code", + "execution_count": 3, + "id": "a510145e-f9df-4460-a15d-377f2b4a6072", + "metadata": {}, + "outputs": [], + 
"source": [ + "import matplotlib\n", + "matplotlib.rcParams['font.sans-serif'] = ['Arial']\n", + "%config InlineBackend.figure_formats = ['svg']" + ] + }, + { + "cell_type": "code", + "execution_count": 4, + "id": "11339963-17b7-44f7-9823-59ac511a4a8a", + "metadata": {}, + "outputs": [], + "source": [ + "import numpy as np\n", + "import pandas as pd\n", + "from tqdm import tqdm\n", + "from collections import defaultdict\n", + "import pickle\n", + "import itertools\n", + "import copy\n", + "import random\n", + "import matplotlib.pyplot as plt\n", + "import seaborn as sns\n", + "import scipy.stats\n", + "from sklearn import metrics\n", + "import itertools\n", + "\n", + "import joblib\n", + "from joblib import Parallel, delayed" + ] + }, + { + "cell_type": "code", + "execution_count": 5, + "id": "0d61671d-b32e-4ec5-8859-85a5fae5c3c4", + "metadata": {}, + "outputs": [], + "source": [ + "from OPE_utils_new import (\n", + " format_data_tensor,\n", + " policy_eval_analytic_finite,\n", + " OPE_IS_h,\n", + " compute_behavior_policy_h,\n", + ")" + ] + }, + { + "cell_type": "code", + "execution_count": 6, + "id": "0c573be0-9afb-4dc0-b018-4c62a0ce15f7", + "metadata": {}, + "outputs": [], + "source": [ + "def policy_eval_helper(π):\n", + " V_H = policy_eval_analytic_finite(P.transpose((1,0,2)), R, π, gamma, H)\n", + " Q_H = [(R + gamma * P.transpose((1,0,2)) @ V_H[h]) for h in range(1,H)] + [R]\n", + " J = isd @ V_H[0]\n", + " # Check recursive relationships\n", + " assert len(Q_H) == H\n", + " assert len(V_H) == H\n", + " assert np.all(Q_H[-1] == R)\n", + " assert np.all(np.sum(π * Q_H[-1], axis=1) == V_H[-1])\n", + " assert np.all(R + gamma * P.transpose((1,0,2)) @ V_H[-1] == Q_H[-2])\n", + " return V_H, Q_H, J" + ] + }, + { + "cell_type": "code", + "execution_count": 7, + "id": "a0ab1c7c-bcd3-43ad-848b-93268fa1a2b6", + "metadata": {}, + "outputs": [], + "source": [ + "def iqm(x):\n", + " return scipy.stats.trim_mean(x, proportiontocut=0.25, axis=None)" + ] + }, + { + 
"cell_type": "code", + "execution_count": 8, + "id": "467e2e02-2f18-4fa9-9d17-3b9ae0e4b8dd", + "metadata": {}, + "outputs": [], + "source": [ + "NSTEPS = H = 20 # max episode length in historical data # Horizon of the MDP\n", + "G_min = -1 # the minimum possible return\n", + "G_max = 1 # the maximum possible return\n", + "nS, nA = 1442, 8\n", + "\n", + "PROB_DIAB = 0.2" + ] + }, + { + "cell_type": "code", + "execution_count": 9, + "id": "fdf1dfc0-7d33-4f7d-952c-5b77668f2968", + "metadata": {}, + "outputs": [], + "source": [ + "# Ground truth MDP model\n", + "MDP_parameters = joblib.load('../data/MDP_parameters.joblib')\n", + "P = MDP_parameters['transition_matrix_absorbing'] # (A, S, S_next)\n", + "R = MDP_parameters['reward_matrix_absorbing_SA'] # (S, A)\n", + "nS, nA = R.shape\n", + "gamma = 0.99\n", + "\n", + "# unif rand isd, mixture of diabetic state\n", + "isd = joblib.load('../data/modified_prior_initial_state_absorbing.joblib')\n", + "isd = (isd > 0).astype(float)\n", + "isd[:720] = isd[:720] / isd[:720].sum() * (1-PROB_DIAB)\n", + "isd[720:] = isd[720:] / isd[720:].sum() * (PROB_DIAB)" + ] + }, + { + "cell_type": "code", + "execution_count": 10, + "id": "c210ea9b-a6e0-42e6-be21-35a8e439c0c8", + "metadata": {}, + "outputs": [], + "source": [ + "# Precomputed optimal policy\n", + "π_star = joblib.load('../data/π_star.joblib')" + ] + }, + { + "attachments": {}, + "cell_type": "markdown", + "id": "9342ffbd-00f1-4916-b418-054a289f20bd", + "metadata": { + "tags": [] + }, + "source": [ + "## Policies" + ] + }, + { + "cell_type": "code", + "execution_count": 11, + "id": "645c4315-1109-4b41-b7ec-177e03e3343d", + "metadata": {}, + "outputs": [], + "source": [ + "# vaso unif, mv abx optimal\n", + "π_unif = (np.tile(π_star.reshape((-1,2,2,2)).sum(axis=3, keepdims=True), (1,1,1,2)).reshape((-1, 8)) / 2)" + ] + }, + { + "attachments": {}, + "cell_type": "markdown", + "id": "32ece6ab-6dca-479a-ab85-0cc3233e9f72", + "metadata": {}, + "source": [ + "### Behavior policy" + 
] + }, + { + "cell_type": "code", + "execution_count": 12, + "id": "639c3b25-c715-4a90-8d1a-094e4478b9c9", + "metadata": {}, + "outputs": [], + "source": [ + "# vaso eps=0.5, mv abx optimal\n", + "π_beh = (np.tile(π_star.reshape((-1,2,2,2)).sum(axis=3, keepdims=True), (1,1,1,2)).reshape((-1, 8)) / 2)\n", + "π_beh[π_star == 1] = 1-eps\n", + "π_beh[π_beh == 0.5] = eps" + ] + }, + { + "cell_type": "code", + "execution_count": 13, + "id": "50a792f2-1f9d-40f7-89db-bd9245eabdf3", + "metadata": {}, + "outputs": [ + { + "data": { + "text/plain": [ + "0.25038354793851164" + ] + }, + "execution_count": 13, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "V_H_beh, Q_H_beh, J_beh = policy_eval_helper(π_beh)\n", + "J_beh" + ] + }, + { + "attachments": {}, + "cell_type": "markdown", + "id": "8d88637a-8298-4f1f-88ef-ab11b05ec08b", + "metadata": {}, + "source": [ + "### Optimal policy" + ] + }, + { + "cell_type": "code", + "execution_count": 14, + "id": "d849f76e-01a1-48f8-9d8a-e9d0a880cadf", + "metadata": {}, + "outputs": [ + { + "data": { + "text/plain": [ + "0.40877179296760224" + ] + }, + "execution_count": 14, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "V_H_star, Q_H_star, J_star = policy_eval_helper(π_star)\n", + "J_star" + ] + }, + { + "attachments": {}, + "cell_type": "markdown", + "id": "44136cb8-d2e0-459d-8ee1-f15b2ed8a62c", + "metadata": {}, + "source": [ + "### flip action for x% states" + ] + }, + { + "cell_type": "code", + "execution_count": 15, + "id": "97e67e52-09c6-4f0d-b435-1304a88a98d6", + "metadata": {}, + "outputs": [], + "source": [ + "def get_π_flip(pol_flip_seed, pol_flip_num):\n", + " rng_flip = np.random.default_rng(pol_flip_seed)\n", + " flip_states = rng_flip.choice(range(1440), pol_flip_num, replace=False)\n", + "\n", + " π_tmp = (np.tile(π_star.reshape((-1,2,2,2)).sum(axis=3, keepdims=True), (1,1,1,2)).reshape((-1, 8)) / 2)\n", + " π_flip = π_tmp.copy()\n", + " π_flip[π_tmp == 0.5] = 0\n", 
+ " π_flip[π_star == 1] = 1\n", + " for s in flip_states:\n", + " π_flip[s, π_tmp[s] == 0.5] = 1\n", + " π_flip[s, π_star[s] == 1] = 0\n", + " assert π_flip.sum(axis=1).mean() == 1\n", + " return π_flip" + ] + }, + { + "cell_type": "code", + "execution_count": 16, + "id": "67c8441f-07bc-425a-ad5b-26b970653986", + "metadata": {}, + "outputs": [], + "source": [ + "πs_flip_list = []\n", + "for flip_num in [0, 50, 100, 200, 300, 400]:\n", + " for flip_seed in [0, 42, 123, 424242, 10000]:\n", + " if flip_num == 0 and flip_seed != 0: continue\n", + " π_flip = get_π_flip(flip_seed, flip_num)\n", + " πs_flip_list.append(π_flip)" + ] + }, + { + "cell_type": "code", + "execution_count": 17, + "id": "976a7213-6cb1-4511-a888-80173c967581", + "metadata": {}, + "outputs": [ + { + "name": "stderr", + "output_type": "stream", + "text": [ + "100%|███████████████████████████████████████████████████████████████████████████████████| 26/26 [00:16<00:00, 1.57it/s]\n" + ] + } + ], + "source": [ + "v_list = []\n", + "for π_eval in tqdm(πs_flip_list):\n", + " _, _, J_eval = policy_eval_helper(π_eval)\n", + " v_list.append(J_eval)" + ] + }, + { + "cell_type": "code", + "execution_count": 18, + "id": "c421cb40-2069-4d04-be56-2456dc6810b4", + "metadata": {}, + "outputs": [ + { + "data": { + "image/svg+xml": [ + "\n", + "\n", + "\n", + "\n", + " \n", + " \n", + " \n", + " \n", + " 2023-04-29T12:30:46.010162\n", + " image/svg+xml\n", + " \n", + " \n", + " Matplotlib v3.3.2, https://matplotlib.org/\n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " 
\n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " 
\n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + "\n" + ], + "text/plain": [ + "
" + ] + }, + "metadata": { + "needs_background": "light" + }, + "output_type": "display_data" + } + ], + "source": [ + "plt.plot(v_list, ls='none', marker='.')\n", + "plt.axhline(J_beh, c='k', ls='--', label='behavior', zorder=0)\n", + "plt.show()" + ] + }, + { + "attachments": {}, + "cell_type": "markdown", + "id": "62b4937e-484a-4740-90ea-8cf91fd19578", + "metadata": {}, + "source": [ + "## Load results" + ] + }, + { + "cell_type": "code", + "execution_count": 20, + "id": "73d56e0a-de4a-4a90-9c3a-c4b78a02e64b", + "metadata": {}, + "outputs": [], + "source": [ + "df_results_0 = pd.read_csv(f'./results/{exp_name}-1/vaso_eps_{eps_str}-observed.csv')" + ] + }, + { + "cell_type": "code", + "execution_count": 21, + "id": "d23aed3c-898f-4c44-b32b-20f744ab3a7d", + "metadata": {}, + "outputs": [], + "source": [ + "dfs_results_orig_list = []\n", + "for flip_num in [0, 50, 100, 200, 300, 400]:\n", + " for flip_seed in [0, 42, 123, 424242, 10000]:\n", + " if flip_num == 0 and flip_seed != 0: continue\n", + " pol_name = f'flip{flip_num}_seed{flip_seed}'\n", + " df_results_ = pd.read_csv(f'./results/{exp_name}-1/vaso_eps_{eps_str}-{pol_name}-orig.csv')\n", + " dfs_results_orig_list.append(df_results_)" + ] + }, + { + "cell_type": "code", + "execution_count": 22, + "id": "074e313e-ccc4-475f-857b-12b83da8b48d", + "metadata": {}, + "outputs": [], + "source": [ + "dfs_results_annotEval_list = []\n", + "for flip_num in [0, 50, 100, 200, 300, 400]:\n", + " for flip_seed in [0, 42, 123, 424242, 10000]:\n", + " if flip_num == 0 and flip_seed != 0: continue\n", + " pol_name = f'flip{flip_num}_seed{flip_seed}'\n", + " df_results_ = pd.read_csv(f'./results/{exp_name}-2/vaso_eps_{eps_str}-{pol_name}-aug_step-annotEval.csv')\n", + " dfs_results_annotEval_list.append(df_results_)" + ] + }, + { + "cell_type": "code", + "execution_count": 170, + "id": "0ec98102-a6c2-4521-a6ba-6267d7369cdf", + "metadata": {}, + "outputs": [], + "source": [ + "noise_list = list(np.arange(0.0, 1.1, 
0.1).round(1))" + ] + }, + { + "cell_type": "code", + "execution_count": 171, + "id": "31b214a1-8930-422a-8b81-774425fa7278", + "metadata": {}, + "outputs": [], + "source": [ + "dfs_results_annotEvalNoise_lists = []\n", + "for noise in noise_list:\n", + " dfs_results_annotEvalNoise_list = []\n", + " for flip_num in [0, 50, 100, 200, 300, 400]:\n", + " for flip_seed in [0, 42, 123, 424242, 10000]:\n", + " if flip_num == 0 and flip_seed != 0: continue\n", + " pol_name = f'flip{flip_num}_seed{flip_seed}'\n", + " df_results_ = pd.read_csv(f'./results/{exp_name}-3/vaso_eps_{eps_str}-{pol_name}-aug_step-annotEval-Noise_{noise}.csv')\n", + " dfs_results_annotEvalNoise_list.append(df_results_)\n", + " dfs_results_annotEvalNoise_lists.append(dfs_results_annotEvalNoise_list)" + ] + }, + { + "attachments": {}, + "cell_type": "markdown", + "id": "1f8bff6a-1a52-4b51-8d6e-89c8c6ffd5c6", + "metadata": {}, + "source": [ + "## Plots" + ] + }, + { + "cell_type": "code", + "execution_count": 172, + "id": "a0802fdd-7fe9-405b-9a13-2557a2065a11", + "metadata": {}, + "outputs": [], + "source": [ + "exp_idx = 13\n", + "π_eval = πs_flip_list[exp_idx]\n", + "J_eval = v_list[exp_idx]\n", + "df_results_orig = dfs_results_orig_list[exp_idx]\n", + "df_results_annotEval = dfs_results_annotEval_list[exp_idx]" + ] + }, + { + "cell_type": "code", + "execution_count": 173, + "id": "a891204f-974c-4be3-8eb6-4c9446b67b7f", + "metadata": {}, + "outputs": [], + "source": [ + "df_results_annotEval_Noise = [dfs[exp_idx] for dfs in dfs_results_annotEvalNoise_lists]" + ] + }, + { + "cell_type": "code", + "execution_count": 174, + "id": "652094bf-2cec-4f04-8c65-39d62b3a5831", + "metadata": {}, + "outputs": [ + { + "data": { + "image/svg+xml": [ + "\n", + "\n", + "\n", + "\n", + " \n", + " \n", + " \n", + " \n", + " 2023-04-29T16:22:38.978369\n", + " image/svg+xml\n", + " \n", + " \n", + " Matplotlib v3.3.2, https://matplotlib.org/\n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " 
\n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " 
\n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " 
\n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " 
\n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " 
\n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " 
\n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " 
\n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " 
\n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + "\n" + ], + "text/plain": [ + "
" + ] + }, + "metadata": { + "needs_background": "light" + }, + "output_type": "display_data" + } + ], + "source": [ + "fig, ax = plt.subplots(figsize=(4,4))\n", + "plt.axhline(J_beh, c='k', ls='--', label='behavior', zorder=0)\n", + "plt.scatter(N_val, J_beh, marker='o', s=40, c='k', alpha=1)\n", + "plt.axhline(J_eval, c='k', ls=':', label='true', zorder=0)\n", + "plt.scatter([N_val]*runs, df_results_0['IS_value'], \n", + " marker='o', s=10, c='k', alpha=0.25, ec='none')\n", + "\n", + "for df, name, color in [\n", + " [df_results_orig, 'original', 'tab:blue'],\n", + " [df_results_annotEval, 'proposed', 'tab:green'],\n", + " [df_results_annotEval_Noise[-1], 'proposed noise', 'tab:cyan'],\n", + "]:\n", + " plt.plot(0,0, c=color, label=name)\n", + " plt.scatter(df['ESS1'], df['IS_value'], marker='o', s=10, c=color, alpha=0.25, ec='none')\n", + " plt.scatter(iqm(df['ESS1']), iqm(df['IS_value']), marker='o', s=40, c=color, alpha=0.8)\n", + " plt.scatter(df['ESS1'], df['WIS_value'], marker='X', s=10, c=color, alpha=0.25, ec='none')\n", + " plt.scatter(iqm(df['ESS1']), iqm(df['WIS_value']), marker='X', s=40, c=color, alpha=0.8)\n", + "\n", + "plt.scatter(-100,0, c='gray', marker='o', label='IS')\n", + "plt.scatter(-100,0, c='gray', marker='X', label='WIS')\n", + "plt.xlabel('ESS')\n", + "plt.ylabel('OPE value')\n", + "# plt.ylim(0.1, 0.7)\n", + "plt.xlim(0, N_val*1.25)\n", + "plt.legend(bbox_to_anchor=(1.04, 1))\n", + "plt.show()" + ] + }, + { + "attachments": {}, + "cell_type": "markdown", + "id": "da761092-92b5-439c-aa25-9efd0ddb674a", + "metadata": {}, + "source": [ + "## Metrics" + ] + }, + { + "cell_type": "code", + "execution_count": 175, + "id": "980a4d9d-038e-43e0-a19b-6d5cd8d4c1c1", + "metadata": {}, + "outputs": [], + "source": [ + "def rmse(y1, y2):\n", + " return np.sqrt(np.mean(np.square(y1-y2)))" + ] + }, + { + "cell_type": "code", + "execution_count": 176, + "id": "949d31e2-c9c1-4f22-8169-4e8befa02ed2", + "metadata": {}, + "outputs": [], + "source": [ + 
"def compute_confmat(J_eval, v_est_list):\n", + " confmat = np.zeros((2,2))\n", + " confmat[int(J_eval >= J_beh), 0] += (v_est_list < df_results_0['IS_value'].iloc[:len(v_est_list)]).sum()\n", + " confmat[int(J_eval >= J_beh), 1] += (v_est_list >= df_results_0['IS_value'].iloc[:len(v_est_list)]).sum()\n", + " return confmat" + ] + }, + { + "cell_type": "code", + "execution_count": 275, + "id": "fd117d72-4840-4510-a2d9-9013ef4a27ee", + "metadata": {}, + "outputs": [], + "source": [ + "def compute_confmat_batch(v_list, v_est_list, v_beh):\n", + " confmat = np.zeros((2,2))\n", + " for J_eval, v_est in zip(v_list, v_est_list):\n", + " confmat[int(J_eval >= J_beh), 0] += (v_est < v_beh).sum()\n", + " confmat[int(J_eval >= J_beh), 1] += (v_est >= v_beh).sum()\n", + " return confmat" + ] + }, + { + "cell_type": "code", + "execution_count": 177, + "id": "ffd8f027-3605-4876-b6e6-4611ca95f1d3", + "metadata": {}, + "outputs": [], + "source": [ + "orig_v_IS = []\n", + "orig_v_WIS = []\n", + "orig_v_ESS = []\n", + "for π_eval, df_results_ in zip(πs_flip_list, dfs_results_orig_list):\n", + " orig_v_IS.append(df_results_['IS_value'])\n", + " orig_v_WIS.append(df_results_['WIS_value'])\n", + " orig_v_ESS.append(df_results_['ESS1'])" + ] + }, + { + "cell_type": "code", + "execution_count": 178, + "id": "3b167971-fe17-49d5-9b0b-a137b40f2531", + "metadata": {}, + "outputs": [], + "source": [ + "annotEval_v_IS = []\n", + "annotEval_v_WIS = []\n", + "annotEval_v_ESS = []\n", + "for π_eval, df_results_ in zip(πs_flip_list, dfs_results_annotEval_list):\n", + " annotEval_v_IS.append(df_results_['IS_value'])\n", + " annotEval_v_WIS.append(df_results_['WIS_value'])\n", + " annotEval_v_ESS.append(df_results_['ESS1'])" + ] + }, + { + "cell_type": "code", + "execution_count": 179, + "id": "2f553302-d7d2-4882-a030-d05ff0b75068", + "metadata": {}, + "outputs": [], + "source": [ + "orig_rmse_value = np.mean([rmse(l,v) for v,l in zip(v_list+v_list, orig_v_IS+orig_v_WIS)])\n", + "orig_spearman_corr 
= np.mean([scipy.stats.spearmanr(v_list+v_list, v_π_est_list).correlation for v_π_est_list in np.array(orig_v_IS+orig_v_WIS).T])" + ] + }, + { + "cell_type": "code", + "execution_count": 180, + "id": "d507ed94-0cca-442f-802c-bf79707fd2b3", + "metadata": {}, + "outputs": [], + "source": [ + "oracle_rmse_value = np.mean([rmse(l,v) for v,l in zip(v_list+v_list, annotEval_v_IS+annotEval_v_WIS)])\n", + "oracle_spearman_corr = np.mean([scipy.stats.spearmanr(v_list+v_list, v_π_est_list).correlation for v_π_est_list in np.array(annotEval_v_IS+annotEval_v_WIS).T])" + ] + }, + { + "cell_type": "code", + "execution_count": 277, + "id": "3a44e898-b5d2-4f56-bda6-ceec6f90987a", + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "=== Noise 0.0\n", + "ESS: 993.9926736346977 3.030694969543688\n", + "RMSE: 0.013±0.005\n", + "Spearman: 0.995±0.003\n", + "Accuracy: 95.7%±3.1% \t FPR: 4.5%±6.9% \t FNR: 4.2%±5.3%\n" + ] + }, + { + "data": { + "text/html": [ + "\n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + "
proposed noise Predicted Ranking $$\\hat{v}(\\pi_e) < \\hat{v}(\\pi_b)$$ $$\\hat{v}(\\pi_e) \\geq \\hat{v}(\\pi_b)$$
True Ranking
$$v(\\pi_e) < v(\\pi_b)$$105050
$$v(\\pi_e) \\geq v(\\pi_b)$$631437
" + ], + "text/plain": [ + "" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "name": "stdout", + "output_type": "stream", + "text": [ + "=== Noise 0.1\n", + "ESS: 993.9926736346977 3.030694969543688\n", + "RMSE: 0.013±0.005\n", + "Spearman: 0.995±0.003\n", + "Accuracy: 95.7%±3.5% \t FPR: 4.4%±7.5% \t FNR: 4.3%±5.7%\n" + ] + }, + { + "data": { + "text/html": [ + "\n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + "
proposed noise Predicted Ranking $$\\hat{v}(\\pi_e) < \\hat{v}(\\pi_b)$$ $$\\hat{v}(\\pi_e) \\geq \\hat{v}(\\pi_b)$$
True Ranking
$$v(\\pi_e) < v(\\pi_b)$$105248
$$v(\\pi_e) \\geq v(\\pi_b)$$651435
" + ], + "text/plain": [ + "" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "name": "stdout", + "output_type": "stream", + "text": [ + "=== Noise 0.2\n", + "ESS: 993.9926736346977 3.030694969543688\n", + "RMSE: 0.014±0.005\n", + "Spearman: 0.994±0.002\n", + "Accuracy: 95.3%±4.1% \t FPR: 4.4%±7.9% \t FNR: 4.9%±6.8%\n" + ] + }, + { + "data": { + "text/html": [ + "\n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + "
proposed noise Predicted Ranking $$\\hat{v}(\\pi_e) < \\hat{v}(\\pi_b)$$ $$\\hat{v}(\\pi_e) \\geq \\hat{v}(\\pi_b)$$
True Ranking
$$v(\\pi_e) < v(\\pi_b)$$105248
$$v(\\pi_e) \\geq v(\\pi_b)$$731427
" + ], + "text/plain": [ + "" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "name": "stdout", + "output_type": "stream", + "text": [ + "=== Noise 0.3\n", + "ESS: 993.9926736346977 3.030694969543688\n", + "RMSE: 0.015±0.005\n", + "Spearman: 0.994±0.002\n", + "Accuracy: 95.2%±4.5% \t FPR: 4.7%±8.1% \t FNR: 4.9%±7.5%\n" + ] + }, + { + "data": { + "text/html": [ + "\n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + "
proposed noise Predicted Ranking $$\\hat{v}(\\pi_e) < \\hat{v}(\\pi_b)$$ $$\\hat{v}(\\pi_e) \\geq \\hat{v}(\\pi_b)$$
True Ranking
$$v(\\pi_e) < v(\\pi_b)$$104852
$$v(\\pi_e) \\geq v(\\pi_b)$$741426
" + ], + "text/plain": [ + "" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "name": "stdout", + "output_type": "stream", + "text": [ + "=== Noise 0.4\n", + "ESS: 993.9926736346977 3.030694969543688\n", + "RMSE: 0.016±0.006\n", + "Spearman: 0.992±0.003\n", + "Accuracy: 94.9%±4.6% \t FPR: 4.5%±8.2% \t FNR: 5.5%±8.0%\n" + ] + }, + { + "data": { + "text/html": [ + "\n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + "
proposed noise Predicted Ranking $$\\hat{v}(\\pi_e) < \\hat{v}(\\pi_b)$$ $$\\hat{v}(\\pi_e) \\geq \\hat{v}(\\pi_b)$$
True Ranking
$$v(\\pi_e) < v(\\pi_b)$$105050
$$v(\\pi_e) \\geq v(\\pi_b)$$821418
" + ], + "text/plain": [ + "" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "name": "stdout", + "output_type": "stream", + "text": [ + "=== Noise 0.5\n", + "ESS: 993.9926736346977 3.030694969543688\n", + "RMSE: 0.018±0.006\n", + "Spearman: 0.991±0.004\n", + "Accuracy: 94.8%±4.8% \t FPR: 4.5%±8.2% \t FNR: 5.7%±8.3%\n" + ] + }, + { + "data": { + "text/html": [ + "\n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + "
proposed noise Predicted Ranking $$\\hat{v}(\\pi_e) < \\hat{v}(\\pi_b)$$ $$\\hat{v}(\\pi_e) \\geq \\hat{v}(\\pi_b)$$
True Ranking
$$v(\\pi_e) < v(\\pi_b)$$105050
$$v(\\pi_e) \\geq v(\\pi_b)$$851415
" + ], + "text/plain": [ + "" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "name": "stdout", + "output_type": "stream", + "text": [ + "=== Noise 0.6\n", + "ESS: 993.9926736346977 3.030694969543688\n", + "RMSE: 0.020±0.007\n", + "Spearman: 0.988±0.005\n", + "Accuracy: 94.2%±5.1% \t FPR: 4.9%±8.8% \t FNR: 6.5%±8.9%\n" + ] + }, + { + "data": { + "text/html": [ + "\n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + "
proposed noise Predicted Ranking $$\\hat{v}(\\pi_e) < \\hat{v}(\\pi_b)$$ $$\\hat{v}(\\pi_e) \\geq \\hat{v}(\\pi_b)$$
True Ranking
$$v(\\pi_e) < v(\\pi_b)$$104654
$$v(\\pi_e) \\geq v(\\pi_b)$$981402
" + ], + "text/plain": [ + "" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "name": "stdout", + "output_type": "stream", + "text": [ + "=== Noise 0.7\n", + "ESS: 993.9926736346977 3.030694969543688\n", + "RMSE: 0.022±0.008\n", + "Spearman: 0.986±0.006\n", + "Accuracy: 93.5%±5.4% \t FPR: 5.2%±9.3% \t FNR: 7.5%±9.8%\n" + ] + }, + { + "data": { + "text/html": [ + "\n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + "
proposed noise Predicted Ranking $$\\hat{v}(\\pi_e) < \\hat{v}(\\pi_b)$$ $$\\hat{v}(\\pi_e) \\geq \\hat{v}(\\pi_b)$$
True Ranking
$$v(\\pi_e) < v(\\pi_b)$$104357
$$v(\\pi_e) \\geq v(\\pi_b)$$1121388
" + ], + "text/plain": [ + "" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "name": "stdout", + "output_type": "stream", + "text": [ + "=== Noise 0.8\n", + "ESS: 993.9926736346977 3.030694969543688\n", + "RMSE: 0.024±0.008\n", + "Spearman: 0.983±0.008\n", + "Accuracy: 92.9%±5.7% \t FPR: 5.7%±10.0% \t FNR: 8.1%±10.3%\n" + ] + }, + { + "data": { + "text/html": [ + "\n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + "
proposed noise Predicted Ranking $$\\hat{v}(\\pi_e) < \\hat{v}(\\pi_b)$$ $$\\hat{v}(\\pi_e) \\geq \\hat{v}(\\pi_b)$$
True Ranking
$$v(\\pi_e) < v(\\pi_b)$$103763
$$v(\\pi_e) \\geq v(\\pi_b)$$1221378
" + ], + "text/plain": [ + "" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "name": "stdout", + "output_type": "stream", + "text": [ + "=== Noise 0.9\n", + "ESS: 993.9926736346977 3.030694969543688\n", + "RMSE: 0.027±0.009\n", + "Spearman: 0.980±0.009\n", + "Accuracy: 92.4%±5.7% \t FPR: 6.5%±10.6% \t FNR: 8.3%±10.5%\n" + ] + }, + { + "data": { + "text/html": [ + "\n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + "
proposed noise Predicted Ranking $$\\hat{v}(\\pi_e) < \\hat{v}(\\pi_b)$$ $$\\hat{v}(\\pi_e) \\geq \\hat{v}(\\pi_b)$$
True Ranking
$$v(\\pi_e) < v(\\pi_b)$$102872
$$v(\\pi_e) \\geq v(\\pi_b)$$1251375
" + ], + "text/plain": [ + "" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "name": "stdout", + "output_type": "stream", + "text": [ + "=== Noise 1.0\n", + "ESS: 993.9926736346977 3.030694969543688\n", + "RMSE: 0.029±0.010\n", + "Spearman: 0.977±0.012\n", + "Accuracy: 91.9%±5.6% \t FPR: 6.7%±10.5% \t FNR: 9.1%±10.6%\n" + ] + }, + { + "data": { + "text/html": [ + "\n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + "
proposed noise Predicted Ranking $$\\hat{v}(\\pi_e) < \\hat{v}(\\pi_b)$$ $$\\hat{v}(\\pi_e) \\geq \\hat{v}(\\pi_b)$$
True Ranking
$$v(\\pi_e) < v(\\pi_b)$$102674
$$v(\\pi_e) \\geq v(\\pi_b)$$1361364
" + ], + "text/plain": [ + "" + ] + }, + "metadata": {}, + "output_type": "display_data" + } + ], + "source": [ + "all_results = []\n", + "for noise, df_results_lists in zip(noise_list, dfs_results_annotEvalNoise_lists):\n", + " print('===', 'Noise', noise)\n", + " v_IS = []\n", + " v_WIS = []\n", + " v_ESS = []\n", + " for π_eval, df_results_ in zip(πs_flip_list, df_results_lists):\n", + " v_IS.append(df_results_['IS_value'])\n", + " v_WIS.append(df_results_['WIS_value'])\n", + " v_ESS.append(df_results_['ESS1'])\n", + " print('ESS:', np.mean(v_ESS), np.array(v_ESS).mean(axis=0).std())\n", + " \n", + " v_est_list = v_IS + v_WIS\n", + " rmse_value = [rmse(l, np.array(v_list+v_list)) for l in np.array(v_est_list).T]\n", + " print('RMSE: {:.3f}±{:.3f}'.format(np.mean(rmse_value).round(3), np.std(rmse_value).round(3)))\n", + " \n", + " spearman_corr = [scipy.stats.spearmanr(l, np.array(v_list+v_list)).correlation for l in np.array(v_est_list).T]\n", + " print('Spearman: {:.3f}±{:.3f}'.format(np.mean(spearman_corr).round(3), np.std(spearman_corr).round(3)))\n", + " \n", + " confmats_ = [compute_confmat_batch(v_list+v_list, l, vb) for l,vb in zip(np.array(v_est_list).T, df_results_0['IS_value'])]\n", + " confmat_sum = sum(confmats_)\n", + " (accuracy, fpr, fnr) = (\n", + " [(cm[0,0]+cm[1,1])/np.sum(cm) for cm in confmats_],\n", + " [cm[0,1]/(cm[0,0]+cm[0,1]) for cm in confmats_], \n", + " [cm[1,0]/(cm[1,0]+cm[1,1]) for cm in confmats_],\n", + " )\n", + " print('Accuracy: {:.1%}±{:.1%} \\t FPR: {:.1%}±{:.1%} \\t FNR: {:.1%}±{:.1%}'.format(\n", + " np.mean(accuracy), np.std(accuracy), \n", + " np.mean(fpr), np.std(fpr), \n", + " np.mean(fnr), np.std(fnr), \n", + " ))\n", + " display(pd.DataFrame(\n", + " (confmat_sum).astype(int), \n", + " index=['$$v(\\pi_e) < v(\\pi_b)$$', '$$v(\\pi_e) \\geq v(\\pi_b)$$'],\n", + " columns=['$$\\hat{v}(\\pi_e) < \\hat{v}(\\pi_b)$$', '$$\\hat{v}(\\pi_e) \\geq \\hat{v}(\\pi_b)$$'],\n", + " ).rename_axis(index='True Ranking', 
columns=f'{name} Predicted Ranking')\\\n", + " .style.background_gradient(cmap='Blues', vmin=0, vmax=1500))\n", + " \n", + " all_results.append({\n", + " 'spearman': spearman_corr,\n", + " 'rmse': rmse_value,\n", + " 'accuracy': accuracy,\n", + " 'fpr': fpr,\n", + " 'fnr': fnr\n", + " })" + ] + }, + { + "cell_type": "code", + "execution_count": 280, + "id": "44db5686-b452-49e2-b97e-77047616022b", + "metadata": {}, + "outputs": [], + "source": [ + "df_all_results = pd.DataFrame(all_results, index=noise_list)\n", + "df_all_results.index.name = 'noise'" + ] + }, + { + "cell_type": "code", + "execution_count": 292, + "id": "2298d338-30f5-4e81-81a8-00392a2d0fd4", + "metadata": {}, + "outputs": [], + "source": [ + "df_all_results_median = df_all_results.applymap(np.median)\n", + "df_all_results_mean = df_all_results.applymap(np.mean)\n", + "df_all_results_std = df_all_results.applymap(np.std)" + ] + }, + { + "cell_type": "code", + "execution_count": 306, + "id": "3aaee0f8-c7f1-4234-b3e5-5410e74f58d3", + "metadata": {}, + "outputs": [ + { + "data": { + "image/svg+xml": [ + "\n", + "\n", + "\n", + "\n", + " \n", + " \n", + " \n", + " \n", + " 2023-05-08T11:11:09.702775\n", + " image/svg+xml\n", + " \n", + " \n", + " Matplotlib v3.3.2, https://matplotlib.org/\n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " 
\n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " 
\n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + "\n" + ], + "text/plain": [ + "
" + ] + }, + "metadata": { + "needs_background": "light" + }, + "output_type": "display_data" + } + ], + "source": [ + "fig, ax = plt.subplots(3,1, sharex=True, figsize=(2.5,3))\n", + "df_all_results_mean['rmse'].plot(ls='-', c='tab:blue', lw=1, ax=ax[0])\n", + "df_all_results_mean['spearman'].plot(ls='-', c='tab:blue', lw=1, ax=ax[1])\n", + "df_all_results_mean['accuracy'].plot(ls='-', c='tab:blue', lw=1, ax=ax[2])\n", + "ax[0].fill_between(noise_list, \n", + " df_all_results_mean['rmse']-df_all_results_std['rmse'], \n", + " df_all_results_mean['rmse']+df_all_results_std['rmse'], fc='tab:blue', alpha=0.2,\n", + " )\n", + "ax[1].fill_between(noise_list, \n", + " df_all_results_mean['spearman']-df_all_results_std['spearman'], \n", + " df_all_results_mean['spearman']+df_all_results_std['spearman'], fc='tab:blue', alpha=0.2,\n", + " )\n", + "ax[2].fill_between(noise_list, \n", + " df_all_results_mean['accuracy']-df_all_results_std['accuracy'], \n", + " df_all_results_mean['accuracy']+df_all_results_std['accuracy'], fc='tab:blue', alpha=0.2,\n", + " )\n", + "ax[0].plot(0, oracle_rmse_value, marker='*', c='k')\n", + "ax[1].plot(0, oracle_spearman_corr, marker='*', c='k')\n", + "ax[2].plot(0, 0.957, marker='*', c='k')\n", + "ax[0].axhline(orig_rmse_value, c='gray', ls=':')\n", + "ax[1].axhline(orig_spearman_corr, c='gray', ls=':')\n", + "ax[2].axhline(0.765, c='gray', ls=':')\n", + "ax[0].set_ylabel('RMSE')\n", + "ax[1].set_ylabel('Spearman\\nCorr.')\n", + "ax[2].set_ylabel('Bin. 
Class.\\nAccuracy')\n", + "ax[0].set_ylim(0, 0.15)\n", + "ax[1].set_ylim(0.45, 1.05)\n", + "ax[2].set_ylim(0.7, 1.05)\n", + "plt.xlabel('Std of annotation noise')\n", + "fig.align_labels()\n", + "# plt.figlegend(\n", + "# [matplotlib.lines.Line2D([], [], c='gray', ls=':'), \n", + "# matplotlib.lines.Line2D([], [], c='k', marker='*', linestyle='none'),\n", + "# matplotlib.lines.Line2D([], [], c='tab:blue'), \n", + "# ],\n", + "# ['baseline', 'ideal setting', 'proposed', ],\n", + "# loc='center', bbox_to_anchor=(0.5,1.0),\n", + "# ncol=2, handlelength=1.333, handletextpad=0.6, columnspacing=1,\n", + "# )\n", + "plt.savefig('fig/annotation_noisy.pdf', bbox_inches='tight')\n", + "plt.show()" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "02f5e299-2124-4f79-a844-e20f5953d585", + "metadata": {}, + "outputs": [], + "source": [] + } + ], + "metadata": { + "kernelspec": { + "display_name": "Python 3 (ipykernel)", + "language": "python", + "name": "python3" + }, + "language_info": { + "codemirror_mode": { + "name": "ipython", + "version": 3 + }, + "file_extension": ".py", + "mimetype": "text/x-python", + "name": "python", + "nbconvert_exporter": "python", + "pygments_lexer": "ipython3", + "version": "3.8.5" + } + }, + "nbformat": 4, + "nbformat_minor": 5 +} diff --git a/sepsisSim/experiments/results-OIS-WIS.ipynb b/sepsisSim/experiments/results-OIS-WIS.ipynb new file mode 100644 index 0000000..aec3622 --- /dev/null +++ b/sepsisSim/experiments/results-OIS-WIS.ipynb @@ -0,0 +1,6641 @@ +{ + "cells": [ + { + "cell_type": "code", + "execution_count": 1, + "id": "89b14dad-9e23-4f54-8bdf-dd5eebd8c025", + "metadata": {}, + "outputs": [], + "source": [ + "# ## Simulation parameters\n", + "exp_name = 'exp-FINAL'\n", + "eps = 0.10\n", + "eps_str = '0_1'\n", + "\n", + "run_idx_length = 1_000\n", + "N_val = 1_000\n", + "runs = 50" + ] + }, + { + "cell_type": "code", + "execution_count": 3, + "id": "a510145e-f9df-4460-a15d-377f2b4a6072", + "metadata": {}, + 
"outputs": [], + "source": [ + "import matplotlib\n", + "matplotlib.rcParams['font.sans-serif'] = ['Arial']\n", + "%config InlineBackend.figure_formats = ['svg']" + ] + }, + { + "cell_type": "code", + "execution_count": 4, + "id": "11339963-17b7-44f7-9823-59ac511a4a8a", + "metadata": {}, + "outputs": [], + "source": [ + "import numpy as np\n", + "import pandas as pd\n", + "from tqdm import tqdm\n", + "from collections import defaultdict\n", + "import pickle\n", + "import itertools\n", + "import copy\n", + "import random\n", + "import matplotlib.pyplot as plt\n", + "import seaborn as sns\n", + "import scipy.stats\n", + "from sklearn import metrics\n", + "import itertools\n", + "\n", + "import joblib\n", + "from joblib import Parallel, delayed" + ] + }, + { + "cell_type": "code", + "execution_count": 5, + "id": "0d61671d-b32e-4ec5-8859-85a5fae5c3c4", + "metadata": {}, + "outputs": [], + "source": [ + "from OPE_utils_new import (\n", + " format_data_tensor,\n", + " policy_eval_analytic_finite,\n", + " OPE_IS_h,\n", + " compute_behavior_policy_h,\n", + ")" + ] + }, + { + "cell_type": "code", + "execution_count": 6, + "id": "0c573be0-9afb-4dc0-b018-4c62a0ce15f7", + "metadata": {}, + "outputs": [], + "source": [ + "def policy_eval_helper(π):\n", + " V_H = policy_eval_analytic_finite(P.transpose((1,0,2)), R, π, gamma, H)\n", + " Q_H = [(R + gamma * P.transpose((1,0,2)) @ V_H[h]) for h in range(1,H)] + [R]\n", + " J = isd @ V_H[0]\n", + " # Check recursive relationships\n", + " assert len(Q_H) == H\n", + " assert len(V_H) == H\n", + " assert np.all(Q_H[-1] == R)\n", + " assert np.all(np.sum(π * Q_H[-1], axis=1) == V_H[-1])\n", + " assert np.all(R + gamma * P.transpose((1,0,2)) @ V_H[-1] == Q_H[-2])\n", + " return V_H, Q_H, J" + ] + }, + { + "cell_type": "code", + "execution_count": 7, + "id": "a0ab1c7c-bcd3-43ad-848b-93268fa1a2b6", + "metadata": {}, + "outputs": [], + "source": [ + "def iqm(x):\n", + " return scipy.stats.trim_mean(x, proportiontocut=0.25, axis=None)" + ] + 
}, + { + "cell_type": "code", + "execution_count": 8, + "id": "467e2e02-2f18-4fa9-9d17-3b9ae0e4b8dd", + "metadata": {}, + "outputs": [], + "source": [ + "NSTEPS = H = 20 # max episode length in historical data # Horizon of the MDP\n", + "G_min = -1 # the minimum possible return\n", + "G_max = 1 # the maximum possible return\n", + "nS, nA = 1442, 8\n", + "\n", + "PROB_DIAB = 0.2" + ] + }, + { + "cell_type": "code", + "execution_count": 9, + "id": "fdf1dfc0-7d33-4f7d-952c-5b77668f2968", + "metadata": {}, + "outputs": [], + "source": [ + "# Ground truth MDP model\n", + "MDP_parameters = joblib.load('../data/MDP_parameters.joblib')\n", + "P = MDP_parameters['transition_matrix_absorbing'] # (A, S, S_next)\n", + "R = MDP_parameters['reward_matrix_absorbing_SA'] # (S, A)\n", + "nS, nA = R.shape\n", + "gamma = 0.99\n", + "\n", + "# unif rand isd, mixture of diabetic state\n", + "isd = joblib.load('../data/modified_prior_initial_state_absorbing.joblib')\n", + "isd = (isd > 0).astype(float)\n", + "isd[:720] = isd[:720] / isd[:720].sum() * (1-PROB_DIAB)\n", + "isd[720:] = isd[720:] / isd[720:].sum() * (PROB_DIAB)" + ] + }, + { + "cell_type": "code", + "execution_count": 10, + "id": "c210ea9b-a6e0-42e6-be21-35a8e439c0c8", + "metadata": {}, + "outputs": [], + "source": [ + "# Precomputed optimal policy\n", + "π_star = joblib.load('../data/π_star.joblib')" + ] + }, + { + "cell_type": "markdown", + "id": "9342ffbd-00f1-4916-b418-054a289f20bd", + "metadata": { + "tags": [] + }, + "source": [ + "## Policies" + ] + }, + { + "cell_type": "code", + "execution_count": 11, + "id": "645c4315-1109-4b41-b7ec-177e03e3343d", + "metadata": {}, + "outputs": [], + "source": [ + "# vaso unif, mv abx optimal\n", + "π_unif = (np.tile(π_star.reshape((-1,2,2,2)).sum(axis=3, keepdims=True), (1,1,1,2)).reshape((-1, 8)) / 2)" + ] + }, + { + "cell_type": "markdown", + "id": "32ece6ab-6dca-479a-ab85-0cc3233e9f72", + "metadata": {}, + "source": [ + "### Behavior policy" + ] + }, + { + "cell_type": "code", 
+ "execution_count": 12, + "id": "639c3b25-c715-4a90-8d1a-094e4478b9c9", + "metadata": {}, + "outputs": [], + "source": [ + "# vaso eps=0.5, mv abx optimal\n", + "π_beh = (np.tile(π_star.reshape((-1,2,2,2)).sum(axis=3, keepdims=True), (1,1,1,2)).reshape((-1, 8)) / 2)\n", + "π_beh[π_star == 1] = 1-eps\n", + "π_beh[π_beh == 0.5] = eps" + ] + }, + { + "cell_type": "code", + "execution_count": 13, + "id": "50a792f2-1f9d-40f7-89db-bd9245eabdf3", + "metadata": {}, + "outputs": [ + { + "data": { + "text/plain": [ + "0.25038354793851164" + ] + }, + "execution_count": 13, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "V_H_beh, Q_H_beh, J_beh = policy_eval_helper(π_beh)\n", + "J_beh" + ] + }, + { + "cell_type": "markdown", + "id": "8d88637a-8298-4f1f-88ef-ab11b05ec08b", + "metadata": {}, + "source": [ + "### Optimal policy" + ] + }, + { + "cell_type": "code", + "execution_count": 14, + "id": "d849f76e-01a1-48f8-9d8a-e9d0a880cadf", + "metadata": {}, + "outputs": [ + { + "data": { + "text/plain": [ + "0.40877179296760224" + ] + }, + "execution_count": 14, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "V_H_star, Q_H_star, J_star = policy_eval_helper(π_star)\n", + "J_star" + ] + }, + { + "cell_type": "markdown", + "id": "44136cb8-d2e0-459d-8ee1-f15b2ed8a62c", + "metadata": {}, + "source": [ + "### flip action for x% states" + ] + }, + { + "cell_type": "code", + "execution_count": 15, + "id": "97e67e52-09c6-4f0d-b435-1304a88a98d6", + "metadata": {}, + "outputs": [], + "source": [ + "def get_π_flip(pol_flip_seed, pol_flip_num):\n", + " rng_flip = np.random.default_rng(pol_flip_seed)\n", + " flip_states = rng_flip.choice(range(1440), pol_flip_num, replace=False)\n", + "\n", + " π_tmp = (np.tile(π_star.reshape((-1,2,2,2)).sum(axis=3, keepdims=True), (1,1,1,2)).reshape((-1, 8)) / 2)\n", + " π_flip = π_tmp.copy()\n", + " π_flip[π_tmp == 0.5] = 0\n", + " π_flip[π_star == 1] = 1\n", + " for s in flip_states:\n", + " π_flip[s, 
π_tmp[s] == 0.5] = 1\n", + " π_flip[s, π_star[s] == 1] = 0\n", + " assert π_flip.sum(axis=1).mean() == 1\n", + " return π_flip" + ] + }, + { + "cell_type": "code", + "execution_count": 16, + "id": "67c8441f-07bc-425a-ad5b-26b970653986", + "metadata": {}, + "outputs": [], + "source": [ + "πs_flip_list = []\n", + "for flip_num in [0, 50, 100, 200, 300, 400]:\n", + " for flip_seed in [0, 42, 123, 424242, 10000]:\n", + " if flip_num == 0 and flip_seed != 0: continue\n", + " π_flip = get_π_flip(flip_seed, flip_num)\n", + " πs_flip_list.append(π_flip)" + ] + }, + { + "cell_type": "code", + "execution_count": 17, + "id": "976a7213-6cb1-4511-a888-80173c967581", + "metadata": {}, + "outputs": [ + { + "name": "stderr", + "output_type": "stream", + "text": [ + "100%|███████████████████████████████████████████████████████████████████████████████████| 26/26 [00:24<00:00, 1.07it/s]\n" + ] + } + ], + "source": [ + "v_list = []\n", + "for π_eval in tqdm(πs_flip_list):\n", + " _, _, J_eval = policy_eval_helper(π_eval)\n", + " v_list.append(J_eval)" + ] + }, + { + "cell_type": "code", + "execution_count": 18, + "id": "c421cb40-2069-4d04-be56-2456dc6810b4", + "metadata": {}, + "outputs": [ + { + "data": { + "image/svg+xml": [ + "\n", + "\n", + "\n", + "\n", + " \n", + " \n", + " \n", + " \n", + " 2023-05-21T12:01:47.984107\n", + " image/svg+xml\n", + " \n", + " \n", + " Matplotlib v3.3.2, https://matplotlib.org/\n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + 
" \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " 
\n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + "\n" + ], + "text/plain": [ + "
" + ] + }, + "metadata": { + "needs_background": "light" + }, + "output_type": "display_data" + } + ], + "source": [ + "fig, ax = plt.subplots(figsize=(4,3))\n", + "plt.plot(v_list, ls='none', marker='.')\n", + "plt.axhline(J_beh, c='k', ls='--', label='behavior', zorder=0)\n", + "plt.xlabel('policy index')\n", + "plt.ylabel('policy value')\n", + "# plt.savefig('fig/sepsisSim-policies.pdf', bbox_inches='tight')\n", + "plt.show()" + ] + }, + { + "cell_type": "markdown", + "id": "62b4937e-484a-4740-90ea-8cf91fd19578", + "metadata": {}, + "source": [ + "## Load results" + ] + }, + { + "cell_type": "code", + "execution_count": 19, + "id": "73d56e0a-de4a-4a90-9c3a-c4b78a02e64b", + "metadata": {}, + "outputs": [], + "source": [ + "df_results_0 = pd.read_csv(f'./results/{exp_name}-1/vaso_eps_{eps_str}-observed.csv')" + ] + }, + { + "cell_type": "code", + "execution_count": 20, + "id": "1d62f73d-34b2-42c4-96e8-5992765189a6", + "metadata": {}, + "outputs": [], + "source": [ + "dfs_results_Naive_list = []\n", + "for flip_num in [0, 50, 100, 200, 300, 400]:\n", + " for flip_seed in [0, 42, 123, 424242, 10000]:\n", + " if flip_num == 0 and flip_seed != 0: continue\n", + " pol_name = f'flip{flip_num}_seed{flip_seed}'\n", + " df_results_ = pd.read_csv(f'./results/{exp_name}-1/vaso_eps_{eps_str}-{pol_name}-aug_step-Naive.csv')\n", + " dfs_results_Naive_list.append(df_results_)" + ] + }, + { + "cell_type": "code", + "execution_count": 21, + "id": "6a3bbe16-6ad9-4131-b42f-e3d8a9525851", + "metadata": {}, + "outputs": [], + "source": [ + "dfs_results_NaiveUW_list = []\n", + "for flip_num in [0, 50, 100, 200, 300, 400]:\n", + " for flip_seed in [0, 42, 123, 424242, 10000]:\n", + " if flip_num == 0 and flip_seed != 0: continue\n", + " pol_name = f'flip{flip_num}_seed{flip_seed}'\n", + " df_results_ = pd.read_csv(f'./results/{exp_name}-1/vaso_eps_{eps_str}-{pol_name}-aug_step-NaiveUW.csv')\n", + " dfs_results_NaiveUW_list.append(df_results_)" + ] + }, + { + "cell_type": "code", + 
"execution_count": 22, + "id": "d23aed3c-898f-4c44-b32b-20f744ab3a7d", + "metadata": {}, + "outputs": [], + "source": [ + "dfs_results_orig_list = []\n", + "for flip_num in [0, 50, 100, 200, 300, 400]:\n", + " for flip_seed in [0, 42, 123, 424242, 10000]:\n", + " if flip_num == 0 and flip_seed != 0: continue\n", + " pol_name = f'flip{flip_num}_seed{flip_seed}'\n", + " df_results_ = pd.read_csv(f'./results/{exp_name}-1/vaso_eps_{eps_str}-{pol_name}-orig.csv')\n", + " dfs_results_orig_list.append(df_results_)" + ] + }, + { + "cell_type": "code", + "execution_count": 23, + "id": "074e313e-ccc4-475f-857b-12b83da8b48d", + "metadata": {}, + "outputs": [], + "source": [ + "dfs_results_annotEval_list = []\n", + "for flip_num in [0, 50, 100, 200, 300, 400]:\n", + " for flip_seed in [0, 42, 123, 424242, 10000]:\n", + " if flip_num == 0 and flip_seed != 0: continue\n", + " pol_name = f'flip{flip_num}_seed{flip_seed}'\n", + " df_results_ = pd.read_csv(f'./results/{exp_name}-2/vaso_eps_{eps_str}-{pol_name}-aug_step-annotEval.csv')\n", + " dfs_results_annotEval_list.append(df_results_)" + ] + }, + { + "cell_type": "code", + "execution_count": 24, + "id": "6cb0164e-1c33-486c-ad59-ad1ebf9a2e01", + "metadata": {}, + "outputs": [], + "source": [ + "dfs_results_annotBeh_list = []\n", + "for flip_num in [0, 50, 100, 200, 300, 400]:\n", + " for flip_seed in [0, 42, 123, 424242, 10000]:\n", + " if flip_num == 0 and flip_seed != 0: continue\n", + " pol_name = f'flip{flip_num}_seed{flip_seed}'\n", + " df_results_ = pd.read_csv(f'./results/{exp_name}-2/vaso_eps_{eps_str}-{pol_name}-aug_step-annotBeh.csv')\n", + " dfs_results_annotBeh_list.append(df_results_)" + ] + }, + { + "cell_type": "code", + "execution_count": 25, + "id": "46613104-2d46-4110-816e-174dd9160a56", + "metadata": {}, + "outputs": [], + "source": [ + "dfs_results_annotZero_list = []\n", + "for flip_num in [0, 50, 100, 200, 300, 400]:\n", + " for flip_seed in [0, 42, 123, 424242, 10000]:\n", + " if flip_num == 0 and 
flip_seed != 0: continue\n", + " pol_name = f'flip{flip_num}_seed{flip_seed}'\n", + " df_results_ = pd.read_csv(f'./results/{exp_name}-2/vaso_eps_{eps_str}-{pol_name}-aug_step-annotZero.csv')\n", + " dfs_results_annotZero_list.append(df_results_)" + ] + }, + { + "cell_type": "code", + "execution_count": 26, + "id": "2796297e-d257-4288-a73d-46999b565983", + "metadata": {}, + "outputs": [], + "source": [ + "dfs_results_annotBehConvertedAM_list = []\n", + "for flip_num in [0, 50, 100, 200, 300, 400]:\n", + " for flip_seed in [0, 42, 123, 424242, 10000]:\n", + " if flip_num == 0 and flip_seed != 0: continue\n", + " pol_name = f'flip{flip_num}_seed{flip_seed}'\n", + " df_results_ = pd.read_csv(f'./results/{exp_name}-5/vaso_eps_{eps_str}-{pol_name}-aug_step-annotBehConvertedAM.csv')\n", + " dfs_results_annotBehConvertedAM_list.append(df_results_)" + ] + }, + { + "cell_type": "markdown", + "id": "1f8bff6a-1a52-4b51-8d6e-89c8c6ffd5c6", + "metadata": { + "jp-MarkdownHeadingCollapsed": true, + "tags": [] + }, + "source": [ + "## Plots" + ] + }, + { + "cell_type": "code", + "execution_count": 164, + "id": "a0802fdd-7fe9-405b-9a13-2557a2065a11", + "metadata": {}, + "outputs": [], + "source": [ + "exp_idx = 0\n", + "π_eval = πs_flip_list[exp_idx]\n", + "J_eval = v_list[exp_idx]\n", + "df_results_orig = dfs_results_orig_list[exp_idx]\n", + "df_results_Naive = dfs_results_Naive_list[exp_idx]\n", + "df_results_NaiveUW = dfs_results_NaiveUW_list[exp_idx]\n", + "df_results_annotEval = dfs_results_annotEval_list[exp_idx]\n", + "df_results_annotBeh = dfs_results_annotBeh_list[exp_idx]\n", + "df_results_annotZero = dfs_results_annotZero_list[exp_idx]\n", + "df_results_annotBehConvertedAM = dfs_results_annotBehConvertedAM_list[exp_idx]" + ] + }, + { + "cell_type": "code", + "execution_count": 166, + "id": "652094bf-2cec-4f04-8c65-39d62b3a5831", + "metadata": {}, + "outputs": [ + { + "data": { + "image/svg+xml": [ + "\n", + "\n", + "\n", + "\n", + " \n", + " \n", + " \n", + " \n", + " 
2023-05-08T09:58:25.277753\n", + " image/svg+xml\n", + " \n", + " \n", + " Matplotlib v3.3.2, https://matplotlib.org/\n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " 
\n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " 
\n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " 
\n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " 
\n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " 
\n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " 
\n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " 
\n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " 
\n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " 
\n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " 
\n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " 
\n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " 
\n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " 
\n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + "\n" + ], + "text/plain": [ + "
" + ] + }, + "metadata": { + "needs_background": "light" + }, + "output_type": "display_data" + } + ], + "source": [ + "fig, ax = plt.subplots(figsize=(4,4))\n", + "plt.axhline(J_beh, c='k', ls='--', label='behavior', zorder=0)\n", + "plt.scatter(N_val, J_beh, marker='o', s=40, c='k', alpha=1)\n", + "plt.axhline(J_eval, c='k', ls=':', label='true', zorder=0)\n", + "plt.scatter([N_val]*runs, df_results_0['IS_value'], \n", + " marker='o', s=10, c='k', alpha=0.25, ec='none')\n", + "\n", + "for df, name, color in [\n", + " [df_results_orig, 'original', 'tab:blue'],\n", + " [df_results_Naive, 'naive w', 'tab:cyan'],\n", + " [df_results_NaiveUW, 'naive uw', 'tab:olive'],\n", + " [df_results_annotEval, 'prop. eval', 'tab:green'],\n", + " [df_results_annotBeh, 'prop. beh', 'tab:orange'],\n", + " [df_results_annotZero, 'prop. zero', 'tab:pink'],\n", + " [df_results_annotBehConvertedAM, 'prop. beh conv', 'gold'],\n", + "]:\n", + " plt.plot(0,0, c=color, label=name)\n", + " plt.scatter(df['ESS1'], df['IS_value'], marker='o', s=10, c=color, alpha=0.25, ec='none')\n", + " plt.scatter(iqm(df['ESS1']), iqm(df['IS_value']), marker='o', s=40, c=color, alpha=0.8)\n", + " plt.scatter(df['ESS1'], df['WIS_value'], marker='X', s=10, c=color, alpha=0.25, ec='none')\n", + " plt.scatter(iqm(df['ESS1']), iqm(df['WIS_value']), marker='X', s=40, c=color, alpha=0.8)\n", + "\n", + "plt.scatter(-100,0, c='gray', marker='o', label='IS')\n", + "plt.scatter(-100,0, c='gray', marker='X', label='WIS')\n", + "plt.xlabel('ESS')\n", + "plt.ylabel('OPE value')\n", + "# plt.ylim(0.1, 0.7)\n", + "plt.xlim(0, N_val*1.25)\n", + "plt.legend(bbox_to_anchor=(1.04, 1))\n", + "plt.show()" + ] + }, + { + "cell_type": "markdown", + "id": "da761092-92b5-439c-aa25-9efd0ddb674a", + "metadata": {}, + "source": [ + "## Metrics" + ] + }, + { + "cell_type": "code", + "execution_count": 27, + "id": "980a4d9d-038e-43e0-a19b-6d5cd8d4c1c1", + "metadata": {}, + "outputs": [], + "source": [ + "def rmse(y1, y2):\n", + " return 
np.sqrt(np.mean(np.square(y1-y2)))" + ] + }, + { + "cell_type": "code", + "execution_count": 28, + "id": "949d31e2-c9c1-4f22-8169-4e8befa02ed2", + "metadata": {}, + "outputs": [], + "source": [ + "def compute_confmat(J_eval, v_est_list):\n", + " confmat = np.zeros((2,2))\n", + " confmat[int(J_eval >= J_beh), 0] += (v_est_list < df_results_0['IS_value'].iloc[:len(v_est_list)]).sum()\n", + " confmat[int(J_eval >= J_beh), 1] += (v_est_list >= df_results_0['IS_value'].iloc[:len(v_est_list)]).sum()\n", + " return confmat" + ] + }, + { + "cell_type": "code", + "execution_count": 29, + "id": "84027389-4866-4c64-a882-ec9448a5c71d", + "metadata": {}, + "outputs": [], + "source": [ + "def compute_confmat_batch(v_list, v_est_list, v_beh):\n", + " confmat = np.zeros((2,2))\n", + " for J_eval, v_est in zip(v_list, v_est_list):\n", + " confmat[int(J_eval >= J_beh), 0] += (v_est < v_beh).sum()\n", + " confmat[int(J_eval >= J_beh), 1] += (v_est >= v_beh).sum()\n", + " return confmat" + ] + }, + { + "cell_type": "markdown", + "id": "5947e759-6ccc-409c-b3f3-45757afecf75", + "metadata": {}, + "source": [ + "### OIS" + ] + }, + { + "cell_type": "code", + "execution_count": 32, + "id": "fdef19bc-dcf8-4738-8d60-8ac401d2022d", + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "=== orig\n", + "ESS: 76.82151971964208\n", + "RMSE: 0.079±0.054\n", + "Spearman: 0.868±0.087\n", + "Accuracy: 79.8%±5.3% \t FPR: 6.4%±7.3% \t FNR: 30.3%±8.8%\n" + ] + }, + { + "data": { + "text/html": [ + "\n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + "
orig Predicted Ranking $$\\hat{v}(\\pi_e) < \\hat{v}(\\pi_b)$$ $$\\hat{v}(\\pi_e) \\geq \\hat{v}(\\pi_b)$$
True Ranking
$$v(\\pi_e) < v(\\pi_b)$$51535
$$v(\\pi_e) \\geq v(\\pi_b)$$227523
" + ], + "text/plain": [ + "" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "name": "stdout", + "output_type": "stream", + "text": [ + "=== naive uw\n", + "ESS: 207.20357031115466\n", + "RMSE: 0.150±0.007\n", + "Spearman: -0.196±0.163\n", + "Accuracy: 42.5%±3.3% \t FPR: 1.8%±5.7% \t FNR: 98.4%±3.7%\n" + ] + }, + { + "data": { + "text/html": [ + "\n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + "
naive uw Predicted Ranking $$\\hat{v}(\\pi_e) < \\hat{v}(\\pi_b)$$ $$\\hat{v}(\\pi_e) \\geq \\hat{v}(\\pi_b)$$
True Ranking
$$v(\\pi_e) < v(\\pi_b)$$54010
$$v(\\pi_e) \\geq v(\\pi_b)$$73812
" + ], + "text/plain": [ + "" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "name": "stdout", + "output_type": "stream", + "text": [ + "=== naive w\n", + "ESS: 300.7636760863779\n", + "RMSE: 0.110±0.008\n", + "Spearman: 0.415±0.168\n", + "Accuracy: 48.7%±9.7% \t FPR: 8.0%±11.2% \t FNR: 83.1%±20.0%\n" + ] + }, + { + "data": { + "text/html": [ + "\n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + "
naive w Predicted Ranking $$\\hat{v}(\\pi_e) < \\hat{v}(\\pi_b)$$ $$\\hat{v}(\\pi_e) \\geq \\hat{v}(\\pi_b)$$
True Ranking
$$v(\\pi_e) < v(\\pi_b)$$50644
$$v(\\pi_e) \\geq v(\\pi_b)$$623127
" + ], + "text/plain": [ + "" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "name": "stdout", + "output_type": "stream", + "text": [ + "=== annotEval\n", + "ESS: 993.9926736346977\n", + "RMSE: 0.013±0.005\n", + "Spearman: 0.995±0.003\n", + "Accuracy: 95.6%±3.2% \t FPR: 4.5%±6.9% \t FNR: 4.3%±5.5%\n" + ] + }, + { + "data": { + "text/html": [ + "\n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + "
annotEval Predicted Ranking $$\\hat{v}(\\pi_e) < \\hat{v}(\\pi_b)$$ $$\\hat{v}(\\pi_e) \\geq \\hat{v}(\\pi_b)$$
True Ranking
$$v(\\pi_e) < v(\\pi_b)$$52525
$$v(\\pi_e) \\geq v(\\pi_b)$$32718
" + ], + "text/plain": [ + "" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "name": "stdout", + "output_type": "stream", + "text": [ + "=== annotBeh\n", + "ESS: 993.9926736346977\n", + "RMSE: 0.070±0.003\n", + "Spearman: 0.962±0.012\n", + "Accuracy: 86.7%±8.3% \t FPR: 20.2%±20.1% \t FNR: 8.3%±11.4%\n" + ] + }, + { + "data": { + "text/html": [ + "\n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + "
annotBeh Predicted Ranking $$\\hat{v}(\\pi_e) < \\hat{v}(\\pi_b)$$ $$\\hat{v}(\\pi_e) \\geq \\hat{v}(\\pi_b)$$
True Ranking
$$v(\\pi_e) < v(\\pi_b)$$439111
$$v(\\pi_e) \\geq v(\\pi_b)$$62688
" + ], + "text/plain": [ + "" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "name": "stdout", + "output_type": "stream", + "text": [ + "=== annotZero\n", + "ESS: 993.9926736346977\n", + "RMSE: 0.166±0.008\n", + "Spearman: 0.925±0.016\n", + "Accuracy: 42.3%±0.0% \t FPR: 0.0%±0.0% \t FNR: 100.0%±0.0%\n" + ] + }, + { + "data": { + "text/html": [ + "\n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + "
annotZero Predicted Ranking $$\\hat{v}(\\pi_e) < \\hat{v}(\\pi_b)$$ $$\\hat{v}(\\pi_e) \\geq \\hat{v}(\\pi_b)$$
True Ranking
$$v(\\pi_e) < v(\\pi_b)$$5500
$$v(\\pi_e) \\geq v(\\pi_b)$$7500
" + ], + "text/plain": [ + "" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "name": "stdout", + "output_type": "stream", + "text": [ + "=== annotBehConvertedAM\n", + "ESS: 993.9926736346977\n", + "RMSE: 0.028±0.007\n", + "Spearman: 0.979±0.010\n", + "Accuracy: 90.1%±5.4% \t FPR: 4.2%±6.6% \t FNR: 14.1%±9.7%\n" + ] + }, + { + "data": { + "text/html": [ + "\n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + "
annotBehConvertedAM Predicted Ranking $$\\hat{v}(\\pi_e) < \\hat{v}(\\pi_b)$$ $$\\hat{v}(\\pi_e) \\geq \\hat{v}(\\pi_b)$$
True Ranking
$$v(\\pi_e) < v(\\pi_b)$$52723
$$v(\\pi_e) \\geq v(\\pi_b)$$106644
" + ], + "text/plain": [ + "" + ] + }, + "metadata": {}, + "output_type": "display_data" + } + ], + "source": [ + "all_results = {}\n", + "for name, df_results_lists in {\n", + " 'orig': dfs_results_orig_list,\n", + " 'naive uw': dfs_results_NaiveUW_list,\n", + " 'naive w': dfs_results_Naive_list,\n", + " 'annotEval': dfs_results_annotEval_list, \n", + " 'annotBeh': dfs_results_annotBeh_list,\n", + " 'annotZero': dfs_results_annotZero_list,\n", + " 'annotBehConvertedAM': dfs_results_annotBehConvertedAM_list,\n", + "}.items():\n", + " print('===', name)\n", + " v_IS = []\n", + " v_WIS = []\n", + " v_ESS = []\n", + " for π_eval, df_results_ in zip(πs_flip_list, df_results_lists):\n", + " v_IS.append(df_results_['IS_value'])\n", + " v_WIS.append(df_results_['WIS_value'])\n", + " v_ESS.append(df_results_['ESS1'])\n", + " print('ESS:', np.mean(v_ESS))\n", + " \n", + " v_est_list = v_IS\n", + " rmse_value = [rmse(l, np.array(v_list)) for l in np.array(v_est_list).T]\n", + " print('RMSE: {:.3f}±{:.3f}'.format(np.mean(rmse_value).round(3), np.std(rmse_value).round(3)))\n", + " \n", + " spearman_corr = [scipy.stats.spearmanr(l, np.array(v_list)).correlation for l in np.array(v_est_list).T]\n", + " print('Spearman: {:.3f}±{:.3f}'.format(np.mean(spearman_corr).round(3), np.std(spearman_corr).round(3)))\n", + " \n", + " confmats_ = [compute_confmat_batch(v_list, l, vb) for l,vb in zip(np.array(v_est_list).T, df_results_0['IS_value'])]\n", + " confmat_sum = sum(confmats_)\n", + " (accuracy, fpr, fnr) = (\n", + " [(cm[0,0]+cm[1,1])/np.sum(cm) for cm in confmats_],\n", + " [cm[0,1]/(cm[0,0]+cm[0,1]) for cm in confmats_], \n", + " [cm[1,0]/(cm[1,0]+cm[1,1]) for cm in confmats_],\n", + " )\n", + " print('Accuracy: {:.1%}±{:.1%} \\t FPR: {:.1%}±{:.1%} \\t FNR: {:.1%}±{:.1%}'.format(\n", + " np.mean(accuracy), np.std(accuracy), \n", + " np.mean(fpr), np.std(fpr), \n", + " np.mean(fnr), np.std(fnr), \n", + " ))\n", + " display(pd.DataFrame(\n", + " (confmat_sum).astype(int), \n", + " 
index=['$$v(\\pi_e) < v(\\pi_b)$$', '$$v(\\pi_e) \\geq v(\\pi_b)$$'],\n", + " columns=['$$\\hat{v}(\\pi_e) < \\hat{v}(\\pi_b)$$', '$$\\hat{v}(\\pi_e) \\geq \\hat{v}(\\pi_b)$$'],\n", + " ).rename_axis(index='True Ranking', columns=f'{name} Predicted Ranking')\\\n", + " .style.background_gradient(cmap='Blues', vmin=0, vmax=750))\n", + " \n", + " all_results[name] = {\n", + " 'spearman': spearman_corr,\n", + " 'rmse': rmse_value,\n", + " 'accuracy': accuracy,\n", + " 'fpr': fpr,\n", + " 'fnr': fnr,\n", + " 'ess': v_ESS,\n", + " }" + ] + }, + { + "cell_type": "code", + "execution_count": 33, + "id": "d6b0a6e1-40ac-4cdc-a8cf-ed05c5d75c78", + "metadata": {}, + "outputs": [], + "source": [ + "df_all_results = pd.DataFrame(all_results).T\n", + "df_all_results_mean = df_all_results.applymap(np.mean)\n", + "df_all_results_std = df_all_results.applymap(np.std)" + ] + }, + { + "cell_type": "code", + "execution_count": 34, + "id": "04f040b5-aafc-49d1-ab95-b99e776cb8cd", + "metadata": {}, + "outputs": [], + "source": [ + "df_all_results_summary = pd.DataFrame()\n", + "df_all_results_summary['rmse'] = \\\n", + " df_all_results_mean['rmse'].apply(lambda x: '{:.3f}'.format(x)) \\\n", + " .str.cat(df_all_results_std['rmse'].apply(lambda x: '{:.3f}'.format(x)), sep='±')\n", + "df_all_results_summary['ess'] = \\\n", + " df_all_results_mean['ess'].apply(lambda x: '{:.1f}'.format(x)) \\\n", + " .str.cat(df_all_results_std['ess'].apply(lambda x: '{:.1f}'.format(x)), sep='±')\n", + "df_all_results_summary['spearman'] = \\\n", + " df_all_results_mean['spearman'].apply(lambda x: '{:.3f}'.format(x)) \\\n", + " .str.cat(df_all_results_std['spearman'].apply(lambda x: '{:.3f}'.format(x)), sep='±')\n", + "df_all_results_summary['accuracy'] = \\\n", + " df_all_results_mean['accuracy'].apply(lambda x: '{:.1f}'.format(x*100)) \\\n", + " .str.cat(df_all_results_std['accuracy'].apply(lambda x: '{:.1f}'.format(x*100)), sep='±')\n", + "df_all_results_summary['fpr'] = \\\n", + " 
df_all_results_mean['fpr'].apply(lambda x: '{:.1f}'.format(x*100)) \\\n", + " .str.cat(df_all_results_std['fpr'].apply(lambda x: '{:.1f}'.format(x*100)), sep='±')\n", + "df_all_results_summary['fnr'] = \\\n", + " df_all_results_mean['fnr'].apply(lambda x: '{:.1f}'.format(x*100)) \\\n", + " .str.cat(df_all_results_std['fnr'].apply(lambda x: '{:.1f}'.format(x*100)), sep='±')" + ] + }, + { + "cell_type": "code", + "execution_count": 35, + "id": "6a135717-be46-43d4-9bb3-d3ea540465e1", + "metadata": {}, + "outputs": [ + { + "data": { + "text/html": [ + "
\n", + "\n", + "\n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + "
rmseessspearmanaccuracyfprfnr
orig0.079±0.05476.8±44.00.868±0.08779.8±5.36.4±7.330.3±8.8
naive uw0.150±0.007207.2±91.5-0.196±0.16342.5±3.31.8±5.798.4±3.7
naive w0.110±0.008300.8±117.60.415±0.16848.7±9.78.0±11.283.1±20.0
annotEval0.013±0.005994.0±10.10.995±0.00395.6±3.24.5±6.94.3±5.5
annotBeh0.070±0.003994.0±10.10.962±0.01286.7±8.320.2±20.18.3±11.4
annotZero0.166±0.008994.0±10.10.925±0.01642.3±0.00.0±0.0100.0±0.0
annotBehConvertedAM0.028±0.007994.0±10.10.979±0.01090.1±5.44.2±6.614.1±9.7
\n", + "
" + ], + "text/plain": [ + " rmse ess spearman accuracy \\\n", + "orig 0.079±0.054 76.8±44.0 0.868±0.087 79.8±5.3 \n", + "naive uw 0.150±0.007 207.2±91.5 -0.196±0.163 42.5±3.3 \n", + "naive w 0.110±0.008 300.8±117.6 0.415±0.168 48.7±9.7 \n", + "annotEval 0.013±0.005 994.0±10.1 0.995±0.003 95.6±3.2 \n", + "annotBeh 0.070±0.003 994.0±10.1 0.962±0.012 86.7±8.3 \n", + "annotZero 0.166±0.008 994.0±10.1 0.925±0.016 42.3±0.0 \n", + "annotBehConvertedAM 0.028±0.007 994.0±10.1 0.979±0.010 90.1±5.4 \n", + "\n", + " fpr fnr \n", + "orig 6.4±7.3 30.3±8.8 \n", + "naive uw 1.8±5.7 98.4±3.7 \n", + "naive w 8.0±11.2 83.1±20.0 \n", + "annotEval 4.5±6.9 4.3±5.5 \n", + "annotBeh 20.2±20.1 8.3±11.4 \n", + "annotZero 0.0±0.0 100.0±0.0 \n", + "annotBehConvertedAM 4.2±6.6 14.1±9.7 " + ] + }, + "execution_count": 35, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "df_all_results_summary[[\n", + " 'rmse', 'ess', 'spearman', 'accuracy', 'fpr', 'fnr'\n", + "]]" + ] + }, + { + "cell_type": "markdown", + "id": "81d32b3b-2988-4603-b9c1-ea19f704c62c", + "metadata": {}, + "source": [ + "### WIS" + ] + }, + { + "cell_type": "code", + "execution_count": 36, + "id": "05795f4a-d1fa-4cfc-b43a-7cb0639b2f13", + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "=== orig\n", + "ESS: 76.82151971964208\n", + "RMSE: 0.136±0.033\n", + "Spearman: 0.523±0.178\n", + "Accuracy: 73.2%±5.1% \t FPR: 61.1%±13.9% \t FNR: 1.6%±3.4%\n" + ] + }, + { + "data": { + "text/html": [ + "\n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + "
orig Predicted Ranking $$\\hat{v}(\\pi_e) < \\hat{v}(\\pi_b)$$ $$\\hat{v}(\\pi_e) \\geq \\hat{v}(\\pi_b)$$
True Ranking
$$v(\\pi_e) < v(\\pi_b)$$214336
$$v(\\pi_e) \\geq v(\\pi_b)$$12738
" + ], + "text/plain": [ + "" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "name": "stdout", + "output_type": "stream", + "text": [ + "=== naive uw\n", + "ESS: 207.20357031115466\n", + "RMSE: 0.102±0.007\n", + "Spearman: 0.478±0.125\n", + "Accuracy: 57.6%±11.2% \t FPR: 21.5%±15.2% \t FNR: 57.7%±25.9%\n" + ] + }, + { + "data": { + "text/html": [ + "\n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + "
naive uw Predicted Ranking $$\\hat{v}(\\pi_e) < \\hat{v}(\\pi_b)$$ $$\\hat{v}(\\pi_e) \\geq \\hat{v}(\\pi_b)$$
True Ranking
$$v(\\pi_e) < v(\\pi_b)$$432118
$$v(\\pi_e) \\geq v(\\pi_b)$$433317
" + ], + "text/plain": [ + "" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "name": "stdout", + "output_type": "stream", + "text": [ + "=== naive w\n", + "ESS: 300.7636760863779\n", + "RMSE: 0.082±0.006\n", + "Spearman: 0.771±0.073\n", + "Accuracy: 79.9%±6.7% \t FPR: 40.0%±17.5% \t FNR: 5.5%±6.1%\n" + ] + }, + { + "data": { + "text/html": [ + "\n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + "
naive w Predicted Ranking $$\\hat{v}(\\pi_e) < \\hat{v}(\\pi_b)$$ $$\\hat{v}(\\pi_e) \\geq \\hat{v}(\\pi_b)$$
True Ranking
$$v(\\pi_e) < v(\\pi_b)$$330220
$$v(\\pi_e) \\geq v(\\pi_b)$$41709
" + ], + "text/plain": [ + "" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "name": "stdout", + "output_type": "stream", + "text": [ + "=== annotEval\n", + "ESS: 993.9926736346977\n", + "RMSE: 0.013±0.005\n", + "Spearman: 0.995±0.003\n", + "Accuracy: 95.7%±3.0% \t FPR: 4.5%±6.9% \t FNR: 4.1%±5.1%\n" + ] + }, + { + "data": { + "text/html": [ + "\n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + "
annotEval Predicted Ranking $$\\hat{v}(\\pi_e) < \\hat{v}(\\pi_b)$$ $$\\hat{v}(\\pi_e) \\geq \\hat{v}(\\pi_b)$$
True Ranking
$$v(\\pi_e) < v(\\pi_b)$$52525
$$v(\\pi_e) \\geq v(\\pi_b)$$31719
" + ], + "text/plain": [ + "" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "name": "stdout", + "output_type": "stream", + "text": [ + "=== annotBeh\n", + "ESS: 993.9926736346977\n", + "RMSE: 0.070±0.003\n", + "Spearman: 0.961±0.012\n", + "Accuracy: 86.9%±8.1% \t FPR: 19.8%±20.1% \t FNR: 8.1%±11.2%\n" + ] + }, + { + "data": { + "text/html": [ + "\n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + "
annotBeh Predicted Ranking $$\\hat{v}(\\pi_e) < \\hat{v}(\\pi_b)$$ $$\\hat{v}(\\pi_e) \\geq \\hat{v}(\\pi_b)$$
True Ranking
$$v(\\pi_e) < v(\\pi_b)$$441109
$$v(\\pi_e) \\geq v(\\pi_b)$$61689
" + ], + "text/plain": [ + "" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "name": "stdout", + "output_type": "stream", + "text": [ + "=== annotZero\n", + "ESS: 993.9926736346977\n", + "RMSE: 0.166±0.008\n", + "Spearman: 0.925±0.016\n", + "Accuracy: 42.3%±0.0% \t FPR: 0.0%±0.0% \t FNR: 100.0%±0.0%\n" + ] + }, + { + "data": { + "text/html": [ + "\n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + "
annotZero Predicted Ranking $$\\hat{v}(\\pi_e) < \\hat{v}(\\pi_b)$$ $$\\hat{v}(\\pi_e) \\geq \\hat{v}(\\pi_b)$$
True Ranking
$$v(\\pi_e) < v(\\pi_b)$$5500
$$v(\\pi_e) \\geq v(\\pi_b)$$7500
" + ], + "text/plain": [ + "" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "name": "stdout", + "output_type": "stream", + "text": [ + "=== annotBehConvertedAM\n", + "ESS: 993.9926736346977\n", + "RMSE: 0.028±0.007\n", + "Spearman: 0.979±0.010\n", + "Accuracy: 90.1%±5.4% \t FPR: 4.2%±6.6% \t FNR: 14.1%±9.7%\n" + ] + }, + { + "data": { + "text/html": [ + "\n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + "
annotBehConvertedAM Predicted Ranking $$\\hat{v}(\\pi_e) < \\hat{v}(\\pi_b)$$ $$\\hat{v}(\\pi_e) \\geq \\hat{v}(\\pi_b)$$
True Ranking
$$v(\\pi_e) < v(\\pi_b)$$52723
$$v(\\pi_e) \\geq v(\\pi_b)$$106644
" + ], + "text/plain": [ + "" + ] + }, + "metadata": {}, + "output_type": "display_data" + } + ], + "source": [ + "all_results = {}\n", + "for name, df_results_lists in {\n", + " 'orig': dfs_results_orig_list,\n", + " 'naive uw': dfs_results_NaiveUW_list,\n", + " 'naive w': dfs_results_Naive_list,\n", + " 'annotEval': dfs_results_annotEval_list, \n", + " 'annotBeh': dfs_results_annotBeh_list,\n", + " 'annotZero': dfs_results_annotZero_list,\n", + " 'annotBehConvertedAM': dfs_results_annotBehConvertedAM_list,\n", + "}.items():\n", + " print('===', name)\n", + " v_IS = []\n", + " v_WIS = []\n", + " v_ESS = []\n", + " for π_eval, df_results_ in zip(πs_flip_list, df_results_lists):\n", + " v_IS.append(df_results_['IS_value'])\n", + " v_WIS.append(df_results_['WIS_value'])\n", + " v_ESS.append(df_results_['ESS1'])\n", + " print('ESS:', np.mean(v_ESS))\n", + " \n", + " v_est_list = v_WIS\n", + " rmse_value = [rmse(l, np.array(v_list)) for l in np.array(v_est_list).T]\n", + " print('RMSE: {:.3f}±{:.3f}'.format(np.mean(rmse_value).round(3), np.std(rmse_value).round(3)))\n", + " \n", + " spearman_corr = [scipy.stats.spearmanr(l, np.array(v_list)).correlation for l in np.array(v_est_list).T]\n", + " print('Spearman: {:.3f}±{:.3f}'.format(np.mean(spearman_corr).round(3), np.std(spearman_corr).round(3)))\n", + " \n", + " confmats_ = [compute_confmat_batch(v_list, l, vb) for l,vb in zip(np.array(v_est_list).T, df_results_0['IS_value'])]\n", + " confmat_sum = sum(confmats_)\n", + " (accuracy, fpr, fnr) = (\n", + " [(cm[0,0]+cm[1,1])/np.sum(cm) for cm in confmats_],\n", + " [cm[0,1]/(cm[0,0]+cm[0,1]) for cm in confmats_], \n", + " [cm[1,0]/(cm[1,0]+cm[1,1]) for cm in confmats_],\n", + " )\n", + " print('Accuracy: {:.1%}±{:.1%} \\t FPR: {:.1%}±{:.1%} \\t FNR: {:.1%}±{:.1%}'.format(\n", + " np.mean(accuracy), np.std(accuracy), \n", + " np.mean(fpr), np.std(fpr), \n", + " np.mean(fnr), np.std(fnr), \n", + " ))\n", + " display(pd.DataFrame(\n", + " (confmat_sum).astype(int), \n", + " 
index=['$$v(\\pi_e) < v(\\pi_b)$$', '$$v(\\pi_e) \\geq v(\\pi_b)$$'],\n", + " columns=['$$\\hat{v}(\\pi_e) < \\hat{v}(\\pi_b)$$', '$$\\hat{v}(\\pi_e) \\geq \\hat{v}(\\pi_b)$$'],\n", + " ).rename_axis(index='True Ranking', columns=f'{name} Predicted Ranking')\\\n", + " .style.background_gradient(cmap='Blues', vmin=0, vmax=750))\n", + " \n", + " all_results[name] = {\n", + " 'spearman': spearman_corr,\n", + " 'rmse': rmse_value,\n", + " 'accuracy': accuracy,\n", + " 'fpr': fpr,\n", + " 'fnr': fnr,\n", + " 'ess': v_ESS,\n", + " }" + ] + }, + { + "cell_type": "code", + "execution_count": 37, + "id": "1e032efb-ce12-411b-8ebc-67cb8782f257", + "metadata": {}, + "outputs": [], + "source": [ + "df_all_results = pd.DataFrame(all_results).T\n", + "df_all_results_mean = df_all_results.applymap(np.mean)\n", + "df_all_results_std = df_all_results.applymap(np.std)" + ] + }, + { + "cell_type": "code", + "execution_count": 38, + "id": "94fb9447-349c-465c-af98-1aad1ad6805a", + "metadata": {}, + "outputs": [], + "source": [ + "df_all_results_summary = pd.DataFrame()\n", + "df_all_results_summary['rmse'] = \\\n", + " df_all_results_mean['rmse'].apply(lambda x: '{:.3f}'.format(x)) \\\n", + " .str.cat(df_all_results_std['rmse'].apply(lambda x: '{:.3f}'.format(x)), sep='±')\n", + "df_all_results_summary['ess'] = \\\n", + " df_all_results_mean['ess'].apply(lambda x: '{:.1f}'.format(x)) \\\n", + " .str.cat(df_all_results_std['ess'].apply(lambda x: '{:.1f}'.format(x)), sep='±')\n", + "df_all_results_summary['spearman'] = \\\n", + " df_all_results_mean['spearman'].apply(lambda x: '{:.3f}'.format(x)) \\\n", + " .str.cat(df_all_results_std['spearman'].apply(lambda x: '{:.3f}'.format(x)), sep='±')\n", + "df_all_results_summary['accuracy'] = \\\n", + " df_all_results_mean['accuracy'].apply(lambda x: '{:.1f}'.format(x*100)) \\\n", + " .str.cat(df_all_results_std['accuracy'].apply(lambda x: '{:.1f}'.format(x*100)), sep='±')\n", + "df_all_results_summary['fpr'] = \\\n", + " 
df_all_results_mean['fpr'].apply(lambda x: '{:.1f}'.format(x*100)) \\\n", + " .str.cat(df_all_results_std['fpr'].apply(lambda x: '{:.1f}'.format(x*100)), sep='±')\n", + "df_all_results_summary['fnr'] = \\\n", + " df_all_results_mean['fnr'].apply(lambda x: '{:.1f}'.format(x*100)) \\\n", + " .str.cat(df_all_results_std['fnr'].apply(lambda x: '{:.1f}'.format(x*100)), sep='±')" + ] + }, + { + "cell_type": "code", + "execution_count": 39, + "id": "14204e4f-9a70-4701-bb87-7ca708fb1a9a", + "metadata": {}, + "outputs": [ + { + "data": { + "text/html": [ + "
\n", + "\n", + "\n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + "
rmseessspearmanaccuracyfprfnr
orig0.136±0.03376.8±44.00.523±0.17873.2±5.161.1±13.91.6±3.4
naive uw0.102±0.007207.2±91.50.478±0.12557.6±11.221.5±15.257.7±25.9
naive w0.082±0.006300.8±117.60.771±0.07379.9±6.740.0±17.55.5±6.1
annotEval0.013±0.005994.0±10.10.995±0.00395.7±3.04.5±6.94.1±5.1
annotBeh0.070±0.003994.0±10.10.961±0.01286.9±8.119.8±20.18.1±11.2
annotZero0.166±0.008994.0±10.10.925±0.01642.3±0.00.0±0.0100.0±0.0
annotBehConvertedAM0.028±0.007994.0±10.10.979±0.01090.1±5.44.2±6.614.1±9.7
\n", + "
" + ], + "text/plain": [ + " rmse ess spearman accuracy \\\n", + "orig 0.136±0.033 76.8±44.0 0.523±0.178 73.2±5.1 \n", + "naive uw 0.102±0.007 207.2±91.5 0.478±0.125 57.6±11.2 \n", + "naive w 0.082±0.006 300.8±117.6 0.771±0.073 79.9±6.7 \n", + "annotEval 0.013±0.005 994.0±10.1 0.995±0.003 95.7±3.0 \n", + "annotBeh 0.070±0.003 994.0±10.1 0.961±0.012 86.9±8.1 \n", + "annotZero 0.166±0.008 994.0±10.1 0.925±0.016 42.3±0.0 \n", + "annotBehConvertedAM 0.028±0.007 994.0±10.1 0.979±0.010 90.1±5.4 \n", + "\n", + " fpr fnr \n", + "orig 61.1±13.9 1.6±3.4 \n", + "naive uw 21.5±15.2 57.7±25.9 \n", + "naive w 40.0±17.5 5.5±6.1 \n", + "annotEval 4.5±6.9 4.1±5.1 \n", + "annotBeh 19.8±20.1 8.1±11.2 \n", + "annotZero 0.0±0.0 100.0±0.0 \n", + "annotBehConvertedAM 4.2±6.6 14.1±9.7 " + ] + }, + "execution_count": 39, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "df_all_results_summary[[\n", + " 'rmse', 'ess', 'spearman', 'accuracy', 'fpr', 'fnr'\n", + "]]" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "d81ee6a8-4b55-4487-ae78-14bcbc535718", + "metadata": {}, + "outputs": [], + "source": [] + } + ], + "metadata": { + "kernelspec": { + "display_name": "Python 3 (ipykernel)", + "language": "python", + "name": "python3" + }, + "language_info": { + "codemirror_mode": { + "name": "ipython", + "version": 3 + }, + "file_extension": ".py", + "mimetype": "text/x-python", + "name": "python", + "nbconvert_exporter": "python", + "pygments_lexer": "ipython3", + "version": "3.8.5" + } + }, + "nbformat": 4, + "nbformat_minor": 5 +} diff --git a/sepsisSim/experiments/results.ipynb b/sepsisSim/experiments/results.ipynb new file mode 100644 index 0000000..ef97fef --- /dev/null +++ b/sepsisSim/experiments/results.ipynb @@ -0,0 +1,6065 @@ +{ + "cells": [ + { + "cell_type": "code", + "execution_count": 1, + "id": "89b14dad-9e23-4f54-8bdf-dd5eebd8c025", + "metadata": {}, + "outputs": [], + "source": [ + "# ## Simulation parameters\n", + "exp_name = 
'exp-FINAL'\n", + "eps = 0.10\n", + "eps_str = '0_1'\n", + "\n", + "run_idx_length = 1_000\n", + "N_val = 1_000\n", + "runs = 50" + ] + }, + { + "cell_type": "code", + "execution_count": 19, + "id": "a510145e-f9df-4460-a15d-377f2b4a6072", + "metadata": {}, + "outputs": [], + "source": [ + "import matplotlib\n", + "matplotlib.rcParams['font.sans-serif'] = ['Arial']\n", + "%config InlineBackend.figure_formats = ['svg']" + ] + }, + { + "cell_type": "code", + "execution_count": 4, + "id": "11339963-17b7-44f7-9823-59ac511a4a8a", + "metadata": {}, + "outputs": [], + "source": [ + "import numpy as np\n", + "import pandas as pd\n", + "from tqdm import tqdm\n", + "from collections import defaultdict\n", + "import pickle\n", + "import itertools\n", + "import copy\n", + "import random\n", + "import matplotlib.pyplot as plt\n", + "import seaborn as sns\n", + "import scipy.stats\n", + "from sklearn import metrics\n", + "import itertools\n", + "\n", + "import joblib\n", + "from joblib import Parallel, delayed" + ] + }, + { + "cell_type": "code", + "execution_count": 5, + "id": "0d61671d-b32e-4ec5-8859-85a5fae5c3c4", + "metadata": {}, + "outputs": [], + "source": [ + "from OPE_utils_new import (\n", + " format_data_tensor,\n", + " policy_eval_analytic_finite,\n", + " OPE_IS_h,\n", + " compute_behavior_policy_h,\n", + ")" + ] + }, + { + "cell_type": "code", + "execution_count": 6, + "id": "0c573be0-9afb-4dc0-b018-4c62a0ce15f7", + "metadata": {}, + "outputs": [], + "source": [ + "def policy_eval_helper(π):\n", + " V_H = policy_eval_analytic_finite(P.transpose((1,0,2)), R, π, gamma, H)\n", + " Q_H = [(R + gamma * P.transpose((1,0,2)) @ V_H[h]) for h in range(1,H)] + [R]\n", + " J = isd @ V_H[0]\n", + " # Check recursive relationships\n", + " assert len(Q_H) == H\n", + " assert len(V_H) == H\n", + " assert np.all(Q_H[-1] == R)\n", + " assert np.all(np.sum(π * Q_H[-1], axis=1) == V_H[-1])\n", + " assert np.all(R + gamma * P.transpose((1,0,2)) @ V_H[-1] == Q_H[-2])\n", + " return V_H, 
Q_H, J" + ] + }, + { + "cell_type": "code", + "execution_count": 7, + "id": "a0ab1c7c-bcd3-43ad-848b-93268fa1a2b6", + "metadata": {}, + "outputs": [], + "source": [ + "def iqm(x):\n", + " return scipy.stats.trim_mean(x, proportiontocut=0.25, axis=None)" + ] + }, + { + "cell_type": "code", + "execution_count": 8, + "id": "467e2e02-2f18-4fa9-9d17-3b9ae0e4b8dd", + "metadata": {}, + "outputs": [], + "source": [ + "NSTEPS = H = 20 # max episode length in historical data # Horizon of the MDP\n", + "G_min = -1 # the minimum possible return\n", + "G_max = 1 # the maximum possible return\n", + "nS, nA = 1442, 8\n", + "\n", + "PROB_DIAB = 0.2" + ] + }, + { + "cell_type": "code", + "execution_count": 9, + "id": "fdf1dfc0-7d33-4f7d-952c-5b77668f2968", + "metadata": {}, + "outputs": [], + "source": [ + "# Ground truth MDP model\n", + "MDP_parameters = joblib.load('../data/MDP_parameters.joblib')\n", + "P = MDP_parameters['transition_matrix_absorbing'] # (A, S, S_next)\n", + "R = MDP_parameters['reward_matrix_absorbing_SA'] # (S, A)\n", + "nS, nA = R.shape\n", + "gamma = 0.99\n", + "\n", + "# unif rand isd, mixture of diabetic state\n", + "isd = joblib.load('../data/modified_prior_initial_state_absorbing.joblib')\n", + "isd = (isd > 0).astype(float)\n", + "isd[:720] = isd[:720] / isd[:720].sum() * (1-PROB_DIAB)\n", + "isd[720:] = isd[720:] / isd[720:].sum() * (PROB_DIAB)" + ] + }, + { + "cell_type": "code", + "execution_count": 10, + "id": "c210ea9b-a6e0-42e6-be21-35a8e439c0c8", + "metadata": {}, + "outputs": [], + "source": [ + "# Precomputed optimal policy\n", + "π_star = joblib.load('../data/π_star.joblib')" + ] + }, + { + "attachments": {}, + "cell_type": "markdown", + "id": "9342ffbd-00f1-4916-b418-054a289f20bd", + "metadata": { + "tags": [] + }, + "source": [ + "## Policies" + ] + }, + { + "cell_type": "code", + "execution_count": 11, + "id": "645c4315-1109-4b41-b7ec-177e03e3343d", + "metadata": {}, + "outputs": [], + "source": [ + "# vaso unif, mv abx optimal\n", + 
"π_unif = (np.tile(π_star.reshape((-1,2,2,2)).sum(axis=3, keepdims=True), (1,1,1,2)).reshape((-1, 8)) / 2)" + ] + }, + { + "attachments": {}, + "cell_type": "markdown", + "id": "32ece6ab-6dca-479a-ab85-0cc3233e9f72", + "metadata": {}, + "source": [ + "### Behavior policy" + ] + }, + { + "cell_type": "code", + "execution_count": 12, + "id": "639c3b25-c715-4a90-8d1a-094e4478b9c9", + "metadata": {}, + "outputs": [], + "source": [ + "# vaso eps=0.5, mv abx optimal\n", + "π_beh = (np.tile(π_star.reshape((-1,2,2,2)).sum(axis=3, keepdims=True), (1,1,1,2)).reshape((-1, 8)) / 2)\n", + "π_beh[π_star == 1] = 1-eps\n", + "π_beh[π_beh == 0.5] = eps" + ] + }, + { + "cell_type": "code", + "execution_count": 13, + "id": "50a792f2-1f9d-40f7-89db-bd9245eabdf3", + "metadata": {}, + "outputs": [ + { + "data": { + "text/plain": [ + "0.25038354793851164" + ] + }, + "execution_count": 13, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "V_H_beh, Q_H_beh, J_beh = policy_eval_helper(π_beh)\n", + "J_beh" + ] + }, + { + "attachments": {}, + "cell_type": "markdown", + "id": "8d88637a-8298-4f1f-88ef-ab11b05ec08b", + "metadata": {}, + "source": [ + "### Optimal policy" + ] + }, + { + "cell_type": "code", + "execution_count": 14, + "id": "d849f76e-01a1-48f8-9d8a-e9d0a880cadf", + "metadata": {}, + "outputs": [ + { + "data": { + "text/plain": [ + "0.40877179296760224" + ] + }, + "execution_count": 14, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "V_H_star, Q_H_star, J_star = policy_eval_helper(π_star)\n", + "J_star" + ] + }, + { + "attachments": {}, + "cell_type": "markdown", + "id": "44136cb8-d2e0-459d-8ee1-f15b2ed8a62c", + "metadata": {}, + "source": [ + "### flip action for x% states" + ] + }, + { + "cell_type": "code", + "execution_count": 15, + "id": "97e67e52-09c6-4f0d-b435-1304a88a98d6", + "metadata": {}, + "outputs": [], + "source": [ + "def get_π_flip(pol_flip_seed, pol_flip_num):\n", + " rng_flip = 
np.random.default_rng(pol_flip_seed)\n", + " flip_states = rng_flip.choice(range(1440), pol_flip_num, replace=False)\n", + "\n", + " π_tmp = (np.tile(π_star.reshape((-1,2,2,2)).sum(axis=3, keepdims=True), (1,1,1,2)).reshape((-1, 8)) / 2)\n", + " π_flip = π_tmp.copy()\n", + " π_flip[π_tmp == 0.5] = 0\n", + " π_flip[π_star == 1] = 1\n", + " for s in flip_states:\n", + " π_flip[s, π_tmp[s] == 0.5] = 1\n", + " π_flip[s, π_star[s] == 1] = 0\n", + " assert π_flip.sum(axis=1).mean() == 1\n", + " return π_flip" + ] + }, + { + "cell_type": "code", + "execution_count": 16, + "id": "67c8441f-07bc-425a-ad5b-26b970653986", + "metadata": {}, + "outputs": [], + "source": [ + "πs_flip_list = []\n", + "for flip_num in [0, 50, 100, 200, 300, 400]:\n", + " for flip_seed in [0, 42, 123, 424242, 10000]:\n", + " if flip_num == 0 and flip_seed != 0: continue\n", + " π_flip = get_π_flip(flip_seed, flip_num)\n", + " πs_flip_list.append(π_flip)" + ] + }, + { + "cell_type": "code", + "execution_count": 17, + "id": "976a7213-6cb1-4511-a888-80173c967581", + "metadata": {}, + "outputs": [ + { + "name": "stderr", + "output_type": "stream", + "text": [ + "100%|███████████████████████████████████████████████████████████████████████████████████| 26/26 [00:15<00:00, 1.69it/s]\n" + ] + } + ], + "source": [ + "v_list = []\n", + "for π_eval in tqdm(πs_flip_list):\n", + " _, _, J_eval = policy_eval_helper(π_eval)\n", + " v_list.append(J_eval)" + ] + }, + { + "cell_type": "code", + "execution_count": 204, + "id": "c421cb40-2069-4d04-be56-2456dc6810b4", + "metadata": {}, + "outputs": [ + { + "data": { + "image/svg+xml": [ + "\n", + "\n", + "\n", + "\n", + " \n", + " \n", + " \n", + " \n", + " 2023-05-21T11:45:16.828337\n", + " image/svg+xml\n", + " \n", + " \n", + " Matplotlib v3.3.2, https://matplotlib.org/\n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " 
\n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " 
\n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + "\n" + ], + "text/plain": [ + "
" + ] + }, + "metadata": { + "needs_background": "light" + }, + "output_type": "display_data" + } + ], + "source": [ + "fig, ax = plt.subplots(figsize=(4,3))\n", + "plt.plot(v_list, ls='none', marker='.')\n", + "plt.axhline(J_beh, c='k', ls='--', label='behavior', zorder=0)\n", + "plt.xlabel('policy index')\n", + "plt.ylabel('policy value')\n", + "plt.savefig('fig/sepsisSim-policies.pdf', bbox_inches='tight')\n", + "plt.show()" + ] + }, + { + "attachments": {}, + "cell_type": "markdown", + "id": "62b4937e-484a-4740-90ea-8cf91fd19578", + "metadata": {}, + "source": [ + "## Load results" + ] + }, + { + "cell_type": "code", + "execution_count": 21, + "id": "73d56e0a-de4a-4a90-9c3a-c4b78a02e64b", + "metadata": {}, + "outputs": [], + "source": [ + "df_results_0 = pd.read_csv(f'./results/{exp_name}/vaso_eps_{eps_str}-observed.csv')" + ] + }, + { + "cell_type": "code", + "execution_count": 68, + "id": "1d62f73d-34b2-42c4-96e8-5992765189a6", + "metadata": {}, + "outputs": [], + "source": [ + "dfs_results_Naive_list = []\n", + "for flip_num in [0, 50, 100, 200, 300, 400]:\n", + " for flip_seed in [0, 42, 123, 424242, 10000]:\n", + " if flip_num == 0 and flip_seed != 0: continue\n", + " pol_name = f'flip{flip_num}_seed{flip_seed}'\n", + " df_results_ = pd.read_csv(f'./results/{exp_name}-1/vaso_eps_{eps_str}-{pol_name}-aug_step-Naive.csv')\n", + " dfs_results_Naive_list.append(df_results_)" + ] + }, + { + "cell_type": "code", + "execution_count": 50, + "id": "6a3bbe16-6ad9-4131-b42f-e3d8a9525851", + "metadata": {}, + "outputs": [], + "source": [ + "dfs_results_NaiveUW_list = []\n", + "for flip_num in [0, 50, 100, 200, 300, 400]:\n", + " for flip_seed in [0, 42, 123, 424242, 10000]:\n", + " if flip_num == 0 and flip_seed != 0: continue\n", + " pol_name = f'flip{flip_num}_seed{flip_seed}'\n", + " df_results_ = pd.read_csv(f'./results/{exp_name}-1/vaso_eps_{eps_str}-{pol_name}-aug_step-NaiveUW.csv')\n", + " dfs_results_NaiveUW_list.append(df_results_)" + ] + }, + { + 
"cell_type": "code", + "execution_count": 51, + "id": "d23aed3c-898f-4c44-b32b-20f744ab3a7d", + "metadata": {}, + "outputs": [], + "source": [ + "dfs_results_orig_list = []\n", + "for flip_num in [0, 50, 100, 200, 300, 400]:\n", + " for flip_seed in [0, 42, 123, 424242, 10000]:\n", + " if flip_num == 0 and flip_seed != 0: continue\n", + " pol_name = f'flip{flip_num}_seed{flip_seed}'\n", + " df_results_ = pd.read_csv(f'./results/{exp_name}-1/vaso_eps_{eps_str}-{pol_name}-orig.csv')\n", + " dfs_results_orig_list.append(df_results_)" + ] + }, + { + "cell_type": "code", + "execution_count": 75, + "id": "074e313e-ccc4-475f-857b-12b83da8b48d", + "metadata": {}, + "outputs": [], + "source": [ + "dfs_results_annotEval_list = []\n", + "for flip_num in [0, 50, 100, 200, 300, 400]:\n", + " for flip_seed in [0, 42, 123, 424242, 10000]:\n", + " if flip_num == 0 and flip_seed != 0: continue\n", + " pol_name = f'flip{flip_num}_seed{flip_seed}'\n", + " df_results_ = pd.read_csv(f'./results/{exp_name}-2/vaso_eps_{eps_str}-{pol_name}-aug_step-annotEval.csv')\n", + " dfs_results_annotEval_list.append(df_results_)" + ] + }, + { + "cell_type": "code", + "execution_count": 162, + "id": "6cb0164e-1c33-486c-ad59-ad1ebf9a2e01", + "metadata": {}, + "outputs": [], + "source": [ + "dfs_results_annotBeh_list = []\n", + "for flip_num in [0, 50, 100, 200, 300, 400]:\n", + " for flip_seed in [0, 42, 123, 424242, 10000]:\n", + " if flip_num == 0 and flip_seed != 0: continue\n", + " pol_name = f'flip{flip_num}_seed{flip_seed}'\n", + " df_results_ = pd.read_csv(f'./results/{exp_name}-2/vaso_eps_{eps_str}-{pol_name}-aug_step-annotBeh.csv')\n", + " dfs_results_annotBeh_list.append(df_results_)" + ] + }, + { + "cell_type": "code", + "execution_count": 161, + "id": "46613104-2d46-4110-816e-174dd9160a56", + "metadata": {}, + "outputs": [], + "source": [ + "dfs_results_annotZero_list = []\n", + "for flip_num in [0, 50, 100, 200, 300, 400]:\n", + " for flip_seed in [0, 42, 123, 424242, 10000]:\n", + " if 
flip_num == 0 and flip_seed != 0: continue\n", + " pol_name = f'flip{flip_num}_seed{flip_seed}'\n", + " df_results_ = pd.read_csv(f'./results/{exp_name}-2/vaso_eps_{eps_str}-{pol_name}-aug_step-annotZero.csv')\n", + " dfs_results_annotZero_list.append(df_results_)" + ] + }, + { + "cell_type": "code", + "execution_count": 163, + "id": "2796297e-d257-4288-a73d-46999b565983", + "metadata": {}, + "outputs": [], + "source": [ + "dfs_results_annotBehConvertedAM_list = []\n", + "for flip_num in [0, 50, 100, 200, 300, 400]:\n", + " for flip_seed in [0, 42, 123, 424242, 10000]:\n", + " if flip_num == 0 and flip_seed != 0: continue\n", + " pol_name = f'flip{flip_num}_seed{flip_seed}'\n", + " df_results_ = pd.read_csv(f'./results/{exp_name}-5/vaso_eps_{eps_str}-{pol_name}-aug_step-annotBehConvertedAM.csv')\n", + " dfs_results_annotBehConvertedAM_list.append(df_results_)" + ] + }, + { + "attachments": {}, + "cell_type": "markdown", + "id": "1f8bff6a-1a52-4b51-8d6e-89c8c6ffd5c6", + "metadata": {}, + "source": [ + "## Plots" + ] + }, + { + "cell_type": "code", + "execution_count": 164, + "id": "a0802fdd-7fe9-405b-9a13-2557a2065a11", + "metadata": {}, + "outputs": [], + "source": [ + "exp_idx = 0\n", + "π_eval = πs_flip_list[exp_idx]\n", + "J_eval = v_list[exp_idx]\n", + "df_results_orig = dfs_results_orig_list[exp_idx]\n", + "df_results_Naive = dfs_results_Naive_list[exp_idx]\n", + "df_results_NaiveUW = dfs_results_NaiveUW_list[exp_idx]\n", + "df_results_annotEval = dfs_results_annotEval_list[exp_idx]\n", + "df_results_annotBeh = dfs_results_annotBeh_list[exp_idx]\n", + "df_results_annotZero = dfs_results_annotZero_list[exp_idx]\n", + "df_results_annotBehConvertedAM = dfs_results_annotBehConvertedAM_list[exp_idx]" + ] + }, + { + "cell_type": "code", + "execution_count": 166, + "id": "652094bf-2cec-4f04-8c65-39d62b3a5831", + "metadata": {}, + "outputs": [ + { + "data": { + "image/svg+xml": [ + "\n", + "\n", + "\n", + "\n", + " \n", + " \n", + " \n", + " \n", + " 
2023-05-08T09:58:25.277753\n", + " image/svg+xml\n", + " \n", + " \n", + " Matplotlib v3.3.2, https://matplotlib.org/\n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " 
\n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " 
\n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " 
\n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " 
\n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " 
\n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " 
\n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " 
\n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " 
\n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " 
\n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " 
\n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " 
\n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " 
\n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " 
\n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + "\n" + ], + "text/plain": [ + "
" + ] + }, + "metadata": { + "needs_background": "light" + }, + "output_type": "display_data" + } + ], + "source": [ + "fig, ax = plt.subplots(figsize=(4,4))\n", + "plt.axhline(J_beh, c='k', ls='--', label='behavior', zorder=0)\n", + "plt.scatter(N_val, J_beh, marker='o', s=40, c='k', alpha=1)\n", + "plt.axhline(J_eval, c='k', ls=':', label='true', zorder=0)\n", + "plt.scatter([N_val]*runs, df_results_0['IS_value'], \n", + " marker='o', s=10, c='k', alpha=0.25, ec='none')\n", + "\n", + "for df, name, color in [\n", + " [df_results_orig, 'original', 'tab:blue'],\n", + " [df_results_Naive, 'naive w', 'tab:cyan'],\n", + " [df_results_NaiveUW, 'naive uw', 'tab:olive'],\n", + " [df_results_annotEval, 'prop. eval', 'tab:green'],\n", + " [df_results_annotBeh, 'prop. beh', 'tab:orange'],\n", + " [df_results_annotZero, 'prop. zero', 'tab:pink'],\n", + " [df_results_annotBehConvertedAM, 'prop. beh conv', 'gold'],\n", + "]:\n", + " plt.plot(0,0, c=color, label=name)\n", + " plt.scatter(df['ESS1'], df['IS_value'], marker='o', s=10, c=color, alpha=0.25, ec='none')\n", + " plt.scatter(iqm(df['ESS1']), iqm(df['IS_value']), marker='o', s=40, c=color, alpha=0.8)\n", + " plt.scatter(df['ESS1'], df['WIS_value'], marker='X', s=10, c=color, alpha=0.25, ec='none')\n", + " plt.scatter(iqm(df['ESS1']), iqm(df['WIS_value']), marker='X', s=40, c=color, alpha=0.8)\n", + "\n", + "plt.scatter(-100,0, c='gray', marker='o', label='IS')\n", + "plt.scatter(-100,0, c='gray', marker='X', label='WIS')\n", + "plt.xlabel('ESS')\n", + "plt.ylabel('OPE value')\n", + "# plt.ylim(0.1, 0.7)\n", + "plt.xlim(0, N_val*1.25)\n", + "plt.legend(bbox_to_anchor=(1.04, 1))\n", + "plt.show()" + ] + }, + { + "attachments": {}, + "cell_type": "markdown", + "id": "da761092-92b5-439c-aa25-9efd0ddb674a", + "metadata": {}, + "source": [ + "## Metrics" + ] + }, + { + "cell_type": "code", + "execution_count": 34, + "id": "980a4d9d-038e-43e0-a19b-6d5cd8d4c1c1", + "metadata": {}, + "outputs": [], + "source": [ + "def rmse(y1, 
y2):\n", + " return np.sqrt(np.mean(np.square(y1-y2)))" + ] + }, + { + "cell_type": "code", + "execution_count": 35, + "id": "949d31e2-c9c1-4f22-8169-4e8befa02ed2", + "metadata": {}, + "outputs": [], + "source": [ + "def compute_confmat(J_eval, v_est_list):\n", + " confmat = np.zeros((2,2))\n", + " confmat[int(J_eval >= J_beh), 0] += (v_est_list < df_results_0['IS_value'].iloc[:len(v_est_list)]).sum()\n", + " confmat[int(J_eval >= J_beh), 1] += (v_est_list >= df_results_0['IS_value'].iloc[:len(v_est_list)]).sum()\n", + " return confmat" + ] + }, + { + "cell_type": "code", + "execution_count": 149, + "id": "84027389-4866-4c64-a882-ec9448a5c71d", + "metadata": {}, + "outputs": [], + "source": [ + "def compute_confmat_batch(v_list, v_est_list, v_beh):\n", + " confmat = np.zeros((2,2))\n", + " for J_eval, v_est in zip(v_list, v_est_list):\n", + " confmat[int(J_eval >= J_beh), 0] += (v_est < v_beh).sum()\n", + " confmat[int(J_eval >= J_beh), 1] += (v_est >= v_beh).sum()\n", + " return confmat" + ] + }, + { + "cell_type": "code", + "execution_count": 168, + "id": "fdef19bc-dcf8-4738-8d60-8ac401d2022d", + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "=== orig\n", + "ESS: 76.82151971964208\n", + "RMSE: 0.113±0.038\n", + "Spearman: 0.596±0.110\n", + "Accuracy: 76.5%±3.5% \t FPR: 33.7%±8.7% \t FNR: 15.9%±4.6%\n" + ] + }, + { + "data": { + "text/html": [ + "\n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + "
orig Predicted Ranking $$\\hat{v}(\\pi_e) < \\hat{v}(\\pi_b)$$ $$\\hat{v}(\\pi_e) \\geq \\hat{v}(\\pi_b)$$
True Ranking
$$v(\\pi_e) < v(\\pi_b)$$729371
$$v(\\pi_e) \\geq v(\\pi_b)$$2391261
" + ], + "text/plain": [ + "" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "name": "stdout", + "output_type": "stream", + "text": [ + "=== naive uw\n", + "ESS: 207.20357031115466\n", + "RMSE: 0.128±0.006\n", + "Spearman: 0.089±0.089\n", + "Accuracy: 50.0%±6.0% \t FPR: 11.6%±8.3% \t FNR: 78.1%±13.6%\n" + ] + }, + { + "data": { + "text/html": [ + "\n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + "
naive uw Predicted Ranking $$\\hat{v}(\\pi_e) < \\hat{v}(\\pi_b)$$ $$\\hat{v}(\\pi_e) \\geq \\hat{v}(\\pi_b)$$
True Ranking
$$v(\\pi_e) < v(\\pi_b)$$972128
$$v(\\pi_e) \\geq v(\\pi_b)$$1171329
" + ], + "text/plain": [ + "" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "name": "stdout", + "output_type": "stream", + "text": [ + "=== naive w\n", + "ESS: 300.7636760863779\n", + "RMSE: 0.097±0.006\n", + "Spearman: 0.420±0.097\n", + "Accuracy: 64.3%±4.7% \t FPR: 24.0%±12.7% \t FNR: 44.3%±11.4%\n" + ] + }, + { + "data": { + "text/html": [ + "\n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + "
naive w Predicted Ranking $$\\hat{v}(\\pi_e) < \\hat{v}(\\pi_b)$$ $$\\hat{v}(\\pi_e) \\geq \\hat{v}(\\pi_b)$$
True Ranking
$$v(\\pi_e) < v(\\pi_b)$$836264
$$v(\\pi_e) \\geq v(\\pi_b)$$664836
" + ], + "text/plain": [ + "" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "name": "stdout", + "output_type": "stream", + "text": [ + "=== annotEval\n", + "ESS: 993.9926736346977\n", + "RMSE: 0.013±0.005\n", + "Spearman: 0.995±0.003\n", + "Accuracy: 95.7%±3.1% \t FPR: 4.5%±6.9% \t FNR: 4.2%±5.3%\n" + ] + }, + { + "data": { + "text/html": [ + "\n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + "
annotEval Predicted Ranking $$\\hat{v}(\\pi_e) < \\hat{v}(\\pi_b)$$ $$\\hat{v}(\\pi_e) \\geq \\hat{v}(\\pi_b)$$
True Ranking
$$v(\\pi_e) < v(\\pi_b)$$105050
$$v(\\pi_e) \\geq v(\\pi_b)$$631437
" + ], + "text/plain": [ + "" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "name": "stdout", + "output_type": "stream", + "text": [ + "=== annotBeh\n", + "ESS: 993.9926736346977\n", + "RMSE: 0.070±0.003\n", + "Spearman: 0.961±0.011\n", + "Accuracy: 86.8%±8.2% \t FPR: 20.0%±20.1% \t FNR: 8.2%±11.3%\n" + ] + }, + { + "data": { + "text/html": [ + "\n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + "
annotBeh Predicted Ranking $$\\hat{v}(\\pi_e) < \\hat{v}(\\pi_b)$$ $$\\hat{v}(\\pi_e) \\geq \\hat{v}(\\pi_b)$$
True Ranking
$$v(\\pi_e) < v(\\pi_b)$$880220
$$v(\\pi_e) \\geq v(\\pi_b)$$1231377
" + ], + "text/plain": [ + "" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "name": "stdout", + "output_type": "stream", + "text": [ + "=== annotZero\n", + "ESS: 993.9926736346977\n", + "RMSE: 0.166±0.008\n", + "Spearman: 0.925±0.016\n", + "Accuracy: 42.3%±0.0% \t FPR: 0.0%±0.0% \t FNR: 100.0%±0.0%\n" + ] + }, + { + "data": { + "text/html": [ + "\n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + "
annotZero Predicted Ranking $$\\hat{v}(\\pi_e) < \\hat{v}(\\pi_b)$$ $$\\hat{v}(\\pi_e) \\geq \\hat{v}(\\pi_b)$$
True Ranking
$$v(\\pi_e) < v(\\pi_b)$$11000
$$v(\\pi_e) \\geq v(\\pi_b)$$15000
" + ], + "text/plain": [ + "" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "name": "stdout", + "output_type": "stream", + "text": [ + "=== annotBehConvertedAM\n", + "ESS: 993.9926736346977\n", + "RMSE: 0.028±0.007\n", + "Spearman: 0.979±0.010\n", + "Accuracy: 90.1%±5.4% \t FPR: 4.2%±6.6% \t FNR: 14.1%±9.7%\n" + ] + }, + { + "data": { + "text/html": [ + "\n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + "
annotBehConvertedAM Predicted Ranking $$\\hat{v}(\\pi_e) < \\hat{v}(\\pi_b)$$ $$\\hat{v}(\\pi_e) \\geq \\hat{v}(\\pi_b)$$
True Ranking
$$v(\\pi_e) < v(\\pi_b)$$105446
$$v(\\pi_e) \\geq v(\\pi_b)$$2121288
" + ], + "text/plain": [ + "" + ] + }, + "metadata": {}, + "output_type": "display_data" + } + ], + "source": [ + "all_results = {}\n", + "for name, df_results_lists in {\n", + " 'orig': dfs_results_orig_list,\n", + " 'naive uw': dfs_results_NaiveUW_list,\n", + " 'naive w': dfs_results_Naive_list,\n", + " 'annotEval': dfs_results_annotEval_list, \n", + " 'annotBeh': dfs_results_annotBeh_list,\n", + " 'annotZero': dfs_results_annotZero_list,\n", + " 'annotBehConvertedAM': dfs_results_annotBehConvertedAM_list,\n", + "}.items():\n", + " print('===', name)\n", + " v_IS = []\n", + " v_WIS = []\n", + " v_ESS = []\n", + " for π_eval, df_results_ in zip(πs_flip_list, df_results_lists):\n", + " v_IS.append(df_results_['IS_value'])\n", + " v_WIS.append(df_results_['WIS_value'])\n", + " v_ESS.append(df_results_['ESS1'])\n", + " print('ESS:', np.mean(v_ESS))\n", + " \n", + " v_est_list = v_IS + v_WIS\n", + " rmse_value = [rmse(l, np.array(v_list+v_list)) for l in np.array(v_est_list).T]\n", + " print('RMSE: {:.3f}±{:.3f}'.format(np.mean(rmse_value).round(3), np.std(rmse_value).round(3)))\n", + " \n", + " spearman_corr = [scipy.stats.spearmanr(l, np.array(v_list+v_list)).correlation for l in np.array(v_est_list).T]\n", + " print('Spearman: {:.3f}±{:.3f}'.format(np.mean(spearman_corr).round(3), np.std(spearman_corr).round(3)))\n", + " \n", + " confmats_ = [compute_confmat_batch(v_list+v_list, l, vb) for l,vb in zip(np.array(v_est_list).T, df_results_0['IS_value'])]\n", + " confmat_sum = sum(confmats_)\n", + " (accuracy, fpr, fnr) = (\n", + " [(cm[0,0]+cm[1,1])/np.sum(cm) for cm in confmats_],\n", + " [cm[0,1]/(cm[0,0]+cm[0,1]) for cm in confmats_], \n", + " [cm[1,0]/(cm[1,0]+cm[1,1]) for cm in confmats_],\n", + " )\n", + " print('Accuracy: {:.1%}±{:.1%} \\t FPR: {:.1%}±{:.1%} \\t FNR: {:.1%}±{:.1%}'.format(\n", + " np.mean(accuracy), np.std(accuracy), \n", + " np.mean(fpr), np.std(fpr), \n", + " np.mean(fnr), np.std(fnr), \n", + " ))\n", + " display(pd.DataFrame(\n", + " 
(confmat_sum).astype(int), \n", + " index=['$$v(\\pi_e) < v(\\pi_b)$$', '$$v(\\pi_e) \\geq v(\\pi_b)$$'],\n", + " columns=['$$\\hat{v}(\\pi_e) < \\hat{v}(\\pi_b)$$', '$$\\hat{v}(\\pi_e) \\geq \\hat{v}(\\pi_b)$$'],\n", + " ).rename_axis(index='True Ranking', columns=f'{name} Predicted Ranking')\\\n", + " .style.background_gradient(cmap='Blues', vmin=0, vmax=1500))\n", + " \n", + " all_results[name] = {\n", + " 'spearman': spearman_corr,\n", + " 'rmse': rmse_value,\n", + " 'accuracy': accuracy,\n", + " 'fpr': fpr,\n", + " 'fnr': fnr,\n", + " 'ess': v_ESS,\n", + " }" + ] + }, + { + "cell_type": "code", + "execution_count": 185, + "id": "d6b0a6e1-40ac-4cdc-a8cf-ed05c5d75c78", + "metadata": {}, + "outputs": [], + "source": [ + "df_all_results = pd.DataFrame(all_results).T\n", + "df_all_results_mean = df_all_results.applymap(np.mean)\n", + "df_all_results_std = df_all_results.applymap(np.std)" + ] + }, + { + "cell_type": "code", + "execution_count": 201, + "id": "04f040b5-aafc-49d1-ab95-b99e776cb8cd", + "metadata": {}, + "outputs": [], + "source": [ + "df_all_results_summary = pd.DataFrame()\n", + "df_all_results_summary['rmse'] = \\\n", + " df_all_results_mean['rmse'].apply(lambda x: '{:.3f}'.format(x)) \\\n", + " .str.cat(df_all_results_std['rmse'].apply(lambda x: '{:.3f}'.format(x)), sep='±')\n", + "df_all_results_summary['ess'] = \\\n", + " df_all_results_mean['ess'].apply(lambda x: '{:.1f}'.format(x)) \\\n", + " .str.cat(df_all_results_std['ess'].apply(lambda x: '{:.1f}'.format(x)), sep='±')\n", + "df_all_results_summary['spearman'] = \\\n", + " df_all_results_mean['spearman'].apply(lambda x: '{:.3f}'.format(x)) \\\n", + " .str.cat(df_all_results_std['spearman'].apply(lambda x: '{:.3f}'.format(x)), sep='±')\n", + "df_all_results_summary['accuracy'] = \\\n", + " df_all_results_mean['accuracy'].apply(lambda x: '{:.1f}'.format(x*100)) \\\n", + " .str.cat(df_all_results_std['accuracy'].apply(lambda x: '{:.1f}'.format(x*100)), sep='±')\n", + 
"df_all_results_summary['fpr'] = \\\n", + " df_all_results_mean['fpr'].apply(lambda x: '{:.1f}'.format(x*100)) \\\n", + " .str.cat(df_all_results_std['fpr'].apply(lambda x: '{:.1f}'.format(x*100)), sep='±')\n", + "df_all_results_summary['fnr'] = \\\n", + " df_all_results_mean['fnr'].apply(lambda x: '{:.1f}'.format(x*100)) \\\n", + " .str.cat(df_all_results_std['fnr'].apply(lambda x: '{:.1f}'.format(x*100)), sep='±')" + ] + }, + { + "cell_type": "code", + "execution_count": 202, + "id": "6a135717-be46-43d4-9bb3-d3ea540465e1", + "metadata": {}, + "outputs": [ + { + "data": { + "text/html": [ + "
\n", + "\n", + "\n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + "
rmseessspearmanaccuracyfprfnr
orig0.113±0.03876.8±44.00.596±0.11076.5±3.533.7±8.715.9±4.6
naive uw0.128±0.006207.2±91.50.089±0.08950.0±6.011.6±8.378.1±13.6
naive w0.097±0.006300.8±117.60.420±0.09764.3±4.724.0±12.744.3±11.4
annotEval0.013±0.005994.0±10.10.995±0.00395.7±3.14.5±6.94.2±5.3
annotBeh0.070±0.003994.0±10.10.961±0.01186.8±8.220.0±20.18.2±11.3
annotZero0.166±0.008994.0±10.10.925±0.01642.3±0.00.0±0.0100.0±0.0
annotBehConvertedAM0.028±0.007994.0±10.10.979±0.01090.1±5.44.2±6.614.1±9.7
\n", + "
" + ], + "text/plain": [ + " rmse ess spearman accuracy \\\n", + "orig 0.113±0.038 76.8±44.0 0.596±0.110 76.5±3.5 \n", + "naive uw 0.128±0.006 207.2±91.5 0.089±0.089 50.0±6.0 \n", + "naive w 0.097±0.006 300.8±117.6 0.420±0.097 64.3±4.7 \n", + "annotEval 0.013±0.005 994.0±10.1 0.995±0.003 95.7±3.1 \n", + "annotBeh 0.070±0.003 994.0±10.1 0.961±0.011 86.8±8.2 \n", + "annotZero 0.166±0.008 994.0±10.1 0.925±0.016 42.3±0.0 \n", + "annotBehConvertedAM 0.028±0.007 994.0±10.1 0.979±0.010 90.1±5.4 \n", + "\n", + " fpr fnr \n", + "orig 33.7±8.7 15.9±4.6 \n", + "naive uw 11.6±8.3 78.1±13.6 \n", + "naive w 24.0±12.7 44.3±11.4 \n", + "annotEval 4.5±6.9 4.2±5.3 \n", + "annotBeh 20.0±20.1 8.2±11.3 \n", + "annotZero 0.0±0.0 100.0±0.0 \n", + "annotBehConvertedAM 4.2±6.6 14.1±9.7 " + ] + }, + "execution_count": 202, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "df_all_results_summary[[\n", + " 'rmse', 'ess', 'spearman', 'accuracy', 'fpr', 'fnr'\n", + "]]" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "d81ee6a8-4b55-4487-ae78-14bcbc535718", + "metadata": {}, + "outputs": [], + "source": [] + } + ], + "metadata": { + "kernelspec": { + "display_name": "Python 3 (ipykernel)", + "language": "python", + "name": "python3" + }, + "language_info": { + "codemirror_mode": { + "name": "ipython", + "version": 3 + }, + "file_extension": ".py", + "mimetype": "text/x-python", + "name": "python", + "nbconvert_exporter": "python", + "pygments_lexer": "ipython3", + "version": "3.8.5" + } + }, + "nbformat": 4, + "nbformat_minor": 5 +} diff --git a/synthetic/bandit_compare-1state.ipynb b/synthetic/bandit_compare-1state.ipynb new file mode 100644 index 0000000..d44a14f --- /dev/null +++ b/synthetic/bandit_compare-1state.ipynb @@ -0,0 +1,872 @@ +{ + "cells": [ + { + "cell_type": "code", + "execution_count": 1, + "id": "a6eac7c9-b2bb-4995-bf22-a9d66093f9d6", + "metadata": {}, + "outputs": [], + "source": [ + "import numpy as np\n", + "import pandas as 
pd\n", + "from IPython.display import display\n", + "import matplotlib.pyplot as plt\n", + "%config InlineBackend.figure_formats = ['svg']\n", + "import matplotlib\n", + "matplotlib.rcParams['text.usetex'] = True\n", + "matplotlib.rcParams['font.sans-serif'] = ['FreeSans']\n", + "import seaborn as sns\n", + "import itertools\n", + "from tqdm import tqdm\n", + "import joblib" + ] + }, + { + "cell_type": "code", + "execution_count": 2, + "id": "16192b2c-09d3-49b7-98d1-edc8c5bbbd22", + "metadata": {}, + "outputs": [], + "source": [ + "d0 = np.array([1.])" + ] + }, + { + "cell_type": "code", + "execution_count": 3, + "id": "3b17c719-1120-434f-9785-e5d3583909e4", + "metadata": {}, + "outputs": [], + "source": [ + "Rs = [\n", + " np.array([[1., 2.],]),\n", + " np.array([[-1., 1.],]),\n", + " np.array([[-1., -2.],]),\n", + "]\n", + "sigmas = [\n", + " np.array([[0.5, 0.5],]),\n", + " np.array([[0.5, 0.5],]),\n", + " np.array([[0.5, 0.5],]),\n", + "]" + ] + }, + { + "cell_type": "code", + "execution_count": 4, + "id": "a2617940-a7a0-4de0-b96c-0c4d87b4365c", + "metadata": {}, + "outputs": [], + "source": [ + "πs = [\n", + " np.array([[1., 0.],]),\n", + " np.array([[0., 1.],]),\n", + " np.array([[0.5, 0.5],]),\n", + " np.array([[0.1, 0.9],]),\n", + " np.array([[0.8, 0.2],]),\n", + "]" + ] + }, + { + "cell_type": "code", + "execution_count": 5, + "id": "3b57307b-c5c4-4b26-be34-5e8d13a2ea20", + "metadata": {}, + "outputs": [], + "source": [ + "use_πD = False" + ] + }, + { + "cell_type": "code", + "execution_count": 6, + "id": "815dabf9-cdff-4d7a-bac4-889d6b046442", + "metadata": {}, + "outputs": [], + "source": [ + "N = 1\n", + "runs = 1000" + ] + }, + { + "cell_type": "code", + "execution_count": 7, + "id": "a82feb16-b247-409b-b556-21e1bab36602", + "metadata": {}, + "outputs": [], + "source": [ + "def single_run():\n", + " np.random.seed(42)\n", + "\n", + " # True value of π_e\n", + " Js = []\n", + " for seed in range(runs):\n", + " rng = 
np.random.default_rng(seed=10+seed)\n", + " x = rng.choice(1, size=N, p=d0)\n", + " a = np.array([rng.choice(2, p=π_e[xi]) for xi in x])\n", + " r = np.array([rng.normal(R[xi,ai], sigma[xi,ai]) for xi,ai in zip(x,a)])\n", + " J = np.sum(r) / N\n", + " Js.append(J)\n", + "\n", + " # Standard IS\n", + " Gs = []\n", + " OISs = []\n", + " WISs = []\n", + " for seed in range(runs):\n", + " rng = np.random.default_rng(seed=10+seed)\n", + " x = rng.choice(1, size=N, p=d0)\n", + " a = np.array([rng.choice(2, p=π_b[xi]) for xi in x])\n", + " r = np.array([rng.normal(R[xi,ai], sigma[xi,ai]) for xi,ai in zip(x,a)])\n", + " G = np.sum(r) / N\n", + " Gs.append(G)\n", + "\n", + " if use_πD:\n", + " π_b_ = np.array([\n", + " [(np.sum((x==0)&(a==0)))/np.sum(x==0), \n", + " (np.sum((x==0)&(a==1)))/np.sum(x==0)],\n", + " ])\n", + " else:\n", + " π_b_ = π_b\n", + "\n", + " rho = π_e[x,a] / π_b_[x,a]\n", + " OISs.append(np.sum(rho * r) / N)\n", + " WISs.append(np.sum(rho * r) / np.sum(rho))\n", + "\n", + "\n", + " # Collect data using π_b - combining counterfactuals with factuals\n", + " FC_OISs_w = []\n", + " FC_WISs_w = []\n", + " for seed in range(runs):\n", + " rng = np.random.default_rng(seed=10+seed)\n", + " rng_c = np.random.default_rng(seed=100000+seed)\n", + " x = rng.choice(1, size=N, p=d0)\n", + " a = np.array([rng.choice(2, p=π_b[xi]) for xi in x])\n", + " r = np.array([rng.normal(R[xi,ai], sigma[xi,ai]) for xi,ai in zip(x,a)])\n", + " rho = π_e[x,a] / π_b[x,a]\n", + "\n", + " # counterfactual flag\n", + " c = np.array([rng_c.choice(2, p=[1-Pc[xi,ai], Pc[xi,ai]]) for xi,ai in zip(x,a)])\n", + "\n", + " # counterfactual reward\n", + " rc = np.array([rng_c.normal(R[xi,1-ai], sigma[xi,1-ai]) for xi,ai in zip(x,a)])\n", + " rc[c==0] = np.nan\n", + "\n", + " # trajectory-wise weight\n", + " w = np.ones(N)\n", + " w[c==1] = ww[x[c==1], a[c==1], a[c==1]]\n", + " wc = np.zeros(N)\n", + " wc[c==1] = ww[x[c==1], a[c==1], 1-a[c==1]]\n", + "\n", + " if use_πD:\n", + " # augmented 
behavior policy\n", + " π_b_ = np.array([\n", + " [(np.sum(w*((x==0)&(a==0)))+np.sum(wc*((x==0)&(a==1)&(c==1))))/(np.sum(w*(x==0))+np.sum(wc*((x==0)&(c==1)))), \n", + " (np.sum(w*((x==0)&(a==1)))+np.sum(wc*((x==0)&(a==0)&(c==1))))/(np.sum(w*(x==0))+np.sum(wc*((x==0)&(c==1))))],\n", + " ])\n", + " else:\n", + " # augmented behavior policy\n", + " π_b_ = np.array([\n", + " [π_b[0,0]*ww[0,0,0]+π_b[0,1]*ww[0,1,0], π_b[0,0]*ww[0,0,1]+π_b[0,1]*ww[0,1,1]],\n", + " ])\n", + " π_b_ = π_b_ / π_b_.sum(axis=1, keepdims=True)\n", + "\n", + " FC_OISs_w.append(\n", + " (np.sum(w* π_e[x,a] / π_b_[x,a] * r) + np.nansum(wc* π_e[x,1-a] / π_b_[x,1-a] * rc)) / (N)\n", + " )\n", + " FC_WISs_w.append(\n", + " (np.sum(w* π_e[x,a] / π_b_[x,a] * r) + np.nansum(wc* π_e[x,1-a] / π_b_[x,1-a] * rc)) / (np.sum(w* π_e[x,a] / π_b_[x,a]) + np.sum((wc* π_e[x,1-a] / π_b_[x,1-a])[c==1])),\n", + " )\n", + "\n", + " df_bias_var = []\n", + " for name, values in [\n", + " ('$\\hat{v}(\\pi_e)$', Js),\n", + " ('$\\hat{v}(\\pi_b)$', Gs),\n", + " ('OIS', OISs),\n", + " ('WIS', WISs),\n", + " ('C-OIS', FC_OISs_w),\n", + " ('C-WIS', FC_WISs_w),\n", + " ]:\n", + " df_bias_var.append([name, \n", + " np.mean(values), \n", + " np.mean(values - d0@np.sum(π_e*R, axis=1)), \n", + " np.sqrt(np.var(values)), \n", + " np.sqrt(np.mean(np.square(values - d0@np.sum(π_e*R, axis=1))))])\n", + " return pd.DataFrame(df_bias_var, columns=['Approach', 'Mean', 'Bias', 'Std', 'RMSE'])" + ] + }, + { + "cell_type": "markdown", + "id": "3fdea8d1-8479-4592-b2ae-4a65c032daf5", + "metadata": {}, + "source": [ + "# Ideal counterfactual annotations, equal weights" + ] + }, + { + "cell_type": "code", + "execution_count": 8, + "id": "0ee15c3b-0180-447d-b9ac-29d91b1e8c2a", + "metadata": {}, + "outputs": [], + "source": [ + "# Counterfactual-augmented IS\n", + "## probability of getting a counterfactual annotation\n", + "Pc = np.array([\n", + " [1., 1.],\n", + "])\n", + "## Weights assigned to factual and counterfactuals\n", + "ww = 
np.array([[\n", + " [0.5, 0.5],\n", + " [0.5, 0.5],\n", + "]])" + ] + }, + { + "cell_type": "markdown", + "id": "72d31031-7e17-4dfc-b762-7d8d36429480", + "metadata": {}, + "source": [ + "## Rs[0] setting" + ] + }, + { + "cell_type": "code", + "execution_count": 9, + "id": "b03a8eda-aadd-4387-89e6-09e4d58d9e8f", + "metadata": {}, + "outputs": [], + "source": [ + "R, sigma = Rs[0], sigmas[0]" + ] + }, + { + "cell_type": "code", + "execution_count": 10, + "id": "e1d34e0a-5206-4e5d-ae9c-f56886346f72", + "metadata": {}, + "outputs": [ + { + "name": "stderr", + "output_type": "stream", + "text": [ + ":36: RuntimeWarning: invalid value encountered in double_scalars\n", + " WISs.append(np.sum(rho * r) / np.sum(rho))\n" + ] + } + ], + "source": [ + "df_out_all_0 = []\n", + "for π_b in πs:\n", + " for π_e in πs:\n", + " df_out = single_run()\n", + " df_out_all_0.append(df_out)" + ] + }, + { + "cell_type": "code", + "execution_count": 11, + "id": "04e99347-6451-4aa6-b7df-407f0907bf37", + "metadata": {}, + "outputs": [ + { + "data": { + "text/html": [ + "\n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + "
π_e [1. 0.] [0. 1.] [0.5 0.5] [0.1 0.9] [0.8 0.2]
π_b
[1. 0.][0.03 0.5 0.5 ]\n", + " [0.03 0.5 0.5 ][-2. 0. 2. ]\n", + " [ 0.02 0.47 0.47][-0.98 0.25 1.01]\n", + " [ 0.03 0.34 0.35][-1.8 0.05 1.8 ]\n", + " [ 0.02 0.43 0.43][-0.37 0.4 0.55]\n", + " [ 0.03 0.41 0.41]
[0. 1.][-1. 0. 1. ]\n", + " [ 0.02 0.47 0.47][0.03 0.5 0.5 ]\n", + " [0.03 0.5 0.5 ][-0.48 0.25 0.54]\n", + " [ 0.03 0.34 0.35][-0.07 0.45 0.45]\n", + " [ 0.03 0.45 0.45][-0.79 0.1 0.8 ]\n", + " [ 0.02 0.39 0.39]
[0.5 0.5][0.01 1.23 1.23]\n", + " [0.03 0.47 0.47][0.09 2.17 2.17]\n", + " [0.02 0.5 0.5 ][0.05 0.71 0.71]\n", + " [0.03 0.34 0.35][0.08 1.86 1.86]\n", + " [0.02 0.46 0.46][0.02 0.69 0.69]\n", + " [0.03 0.39 0.39]
[0.1 0.9][0.08 3.46 3.46]\n", + " [0.01 0.47 0.47][0.02 0.88 0.88]\n", + " [0.04 0.5 0.5 ][0.05 1.45 1.45]\n", + " [0.03 0.34 0.35][0.03 0.59 0.59]\n", + " [0.03 0.45 0.45][0.07 2.64 2.64]\n", + " [0.02 0.39 0.39]
[0.8 0.2][0.03 0.75 0.75]\n", + " [0.04 0.48 0.48][0.05 4.28 4.28]\n", + " [0.01 0.49 0.49][0.04 1.92 1.92]\n", + " [0.03 0.34 0.35][0.05 3.8 3.81]\n", + " [0.02 0.44 0.44][0.03 0.65 0.65]\n", + " [0.03 0.4 0.4 ]
" + ], + "text/plain": [ + "" + ] + }, + "metadata": {}, + "output_type": "display_data" + } + ], + "source": [ + "df_tmp = pd.DataFrame(index=[str(π)[1:-1] for π in πs], columns=[str(π)[1:-1] for π in πs])\n", + "df_tmp.index.name = 'π_b'\n", + "df_tmp.columns.name = 'π_e'\n", + "for (i, π_b), (j, π_e) in itertools.product(enumerate(πs), enumerate(πs)):\n", + " ix = i*len(πs)+j\n", + " df_tmp.iloc[i,j] = str(df_out_all_0[ix].iloc[[2,4], [2,3,4]].round(2).values)[1:-1]\n", + "display(df_tmp.style.set_properties(**{'white-space': 'pre-wrap'}))" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "48c7d843-695c-488d-9473-4eb8aa13c4d7", + "metadata": { + "collapsed": true, + "jupyter": { + "outputs_hidden": true + }, + "tags": [] + }, + "outputs": [], + "source": [ + "for (i, π_b), (j, π_e) in itertools.product(enumerate(πs), enumerate(πs)):\n", + " ix = i*len(πs)+j\n", + " values = df_out_all_0[ix].iloc[[2,4], [2,3,4]].round(2).values\n", + " print(\"\"\"\\mask{{0.00}}{{{}}} & \\mask{{0.00}}{{{}}} & \\mask{{0.00}}{{{}}} \\\\\\\\ \\mask{{0.00}}{{{}}} & \\mask{{0.00}}{{{}}} & \\mask{{0.00}}{{{}}}\"\"\".format(*list(values.ravel())))" + ] + }, + { + "cell_type": "code", + "execution_count": 14, + "id": "23dd71a3-c736-43aa-8fed-5030ae1798f4", + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "$[\\mask{0.0}{1.0},\\mask{0.0}{0.0}]$ \n", + "& \\scalebox{0.8}{$\\begin{matrix} \\mask{0.00}{0.03} & \\mask{0.00}{0.5} & \\mask{0.00}{0.5} \\\\ \\mask{0.00}{0.03} & \\mask{0.00}{0.5} & \\mask{0.00}{0.5} \\end{matrix}$} \n", + "& \\scalebox{0.8}{$\\begin{matrix} \\mask{0.00}{\\shortminus 2} & \\mask{0.00}{0} & \\mask{0.00}{2} \\\\ \\mask{0.00}{0.02} & \\mask{0.00}{0.47} & \\mask{0.00}{0.47} \\end{matrix}$} \n", + "& \\scalebox{0.8}{$\\begin{matrix} \\mask{0.00}{\\shortminus 0.98} & \\mask{0.00}{0.25} & \\mask{0.00}{1.01} \\\\ \\mask{0.00}{0.03} & \\mask{0.00}{0.34} & \\mask{0.00}{0.35} \\end{matrix}$} \n", + "& 
\\scalebox{0.8}{$\\begin{matrix} \\mask{0.00}{\\shortminus 1.8} & \\mask{0.00}{0.05} & \\mask{0.00}{1.8} \\\\ \\mask{0.00}{0.02} & \\mask{0.00}{0.43} & \\mask{0.00}{0.43} \\end{matrix}$} \n", + "& \\scalebox{0.8}{$\\begin{matrix} \\mask{0.00}{\\shortminus 0.37} & \\mask{0.00}{0.4} & \\mask{0.00}{0.55} \\\\ \\mask{0.00}{0.03} & \\mask{0.00}{0.41} & \\mask{0.00}{0.41} \\end{matrix}$} \n", + "\\\\[12pt]\n", + "$[\\mask{0.0}{0.0},\\mask{0.0}{1.0}]$ \n", + "& \\scalebox{0.8}{$\\begin{matrix} \\mask{0.00}{\\shortminus 1} & \\mask{0.00}{0} & \\mask{0.00}{1} \\\\ \\mask{0.00}{0.02} & \\mask{0.00}{0.47} & \\mask{0.00}{0.47} \\end{matrix}$} \n", + "& \\scalebox{0.8}{$\\begin{matrix} \\mask{0.00}{0.03} & \\mask{0.00}{0.5} & \\mask{0.00}{0.5} \\\\ \\mask{0.00}{0.03} & \\mask{0.00}{0.5} & \\mask{0.00}{0.5} \\end{matrix}$} \n", + "& \\scalebox{0.8}{$\\begin{matrix} \\mask{0.00}{\\shortminus 0.48} & \\mask{0.00}{0.25} & \\mask{0.00}{0.54} \\\\ \\mask{0.00}{0.03} & \\mask{0.00}{0.34} & \\mask{0.00}{0.35} \\end{matrix}$} \n", + "& \\scalebox{0.8}{$\\begin{matrix} \\mask{0.00}{\\shortminus 0.07} & \\mask{0.00}{0.45} & \\mask{0.00}{0.45} \\\\ \\mask{0.00}{0.03} & \\mask{0.00}{0.45} & \\mask{0.00}{0.45} \\end{matrix}$} \n", + "& \\scalebox{0.8}{$\\begin{matrix} \\mask{0.00}{\\shortminus 0.79} & \\mask{0.00}{0.1} & \\mask{0.00}{0.8} \\\\ \\mask{0.00}{0.02} & \\mask{0.00}{0.39} & \\mask{0.00}{0.39} \\end{matrix}$} \n", + "\\\\[12pt]\n", + "$[\\mask{0.0}{0.5},\\mask{0.0}{0.5}]$ \n", + "& \\scalebox{0.8}{$\\begin{matrix} \\mask{0.00}{0.01} & \\mask{0.00}{1.23} & \\mask{0.00}{1.23} \\\\ \\mask{0.00}{0.03} & \\mask{0.00}{0.47} & \\mask{0.00}{0.47} \\end{matrix}$} \n", + "& \\scalebox{0.8}{$\\begin{matrix} \\mask{0.00}{0.09} & \\mask{0.00}{2.17} & \\mask{0.00}{2.17} \\\\ \\mask{0.00}{0.02} & \\mask{0.00}{0.5} & \\mask{0.00}{0.5} \\end{matrix}$} \n", + "& \\scalebox{0.8}{$\\begin{matrix} \\mask{0.00}{0.05} & \\mask{0.00}{0.71} & \\mask{0.00}{0.71} \\\\ \\mask{0.00}{0.03} & \\mask{0.00}{0.34} 
& \\mask{0.00}{0.35} \\end{matrix}$} \n", + "& \\scalebox{0.8}{$\\begin{matrix} \\mask{0.00}{0.08} & \\mask{0.00}{1.86} & \\mask{0.00}{1.86} \\\\ \\mask{0.00}{0.02} & \\mask{0.00}{0.46} & \\mask{0.00}{0.46} \\end{matrix}$} \n", + "& \\scalebox{0.8}{$\\begin{matrix} \\mask{0.00}{0.02} & \\mask{0.00}{0.69} & \\mask{0.00}{0.69} \\\\ \\mask{0.00}{0.03} & \\mask{0.00}{0.39} & \\mask{0.00}{0.39} \\end{matrix}$} \n", + "\\\\[12pt]\n", + "$[\\mask{0.0}{0.1},\\mask{0.0}{0.9}]$ \n", + "& \\scalebox{0.8}{$\\begin{matrix} \\mask{0.00}{0.08} & \\mask{0.00}{3.46} & \\mask{0.00}{3.46} \\\\ \\mask{0.00}{0.01} & \\mask{0.00}{0.47} & \\mask{0.00}{0.47} \\end{matrix}$} \n", + "& \\scalebox{0.8}{$\\begin{matrix} \\mask{0.00}{0.02} & \\mask{0.00}{0.88} & \\mask{0.00}{0.88} \\\\ \\mask{0.00}{0.04} & \\mask{0.00}{0.5} & \\mask{0.00}{0.5} \\end{matrix}$} \n", + "& \\scalebox{0.8}{$\\begin{matrix} \\mask{0.00}{0.05} & \\mask{0.00}{1.45} & \\mask{0.00}{1.45} \\\\ \\mask{0.00}{0.03} & \\mask{0.00}{0.34} & \\mask{0.00}{0.35} \\end{matrix}$} \n", + "& \\scalebox{0.8}{$\\begin{matrix} \\mask{0.00}{0.03} & \\mask{0.00}{0.59} & \\mask{0.00}{0.59} \\\\ \\mask{0.00}{0.03} & \\mask{0.00}{0.45} & \\mask{0.00}{0.45} \\end{matrix}$} \n", + "& \\scalebox{0.8}{$\\begin{matrix} \\mask{0.00}{0.07} & \\mask{0.00}{2.64} & \\mask{0.00}{2.64} \\\\ \\mask{0.00}{0.02} & \\mask{0.00}{0.39} & \\mask{0.00}{0.39} \\end{matrix}$} \n", + "\\\\[12pt]\n", + "$[\\mask{0.0}{0.8},\\mask{0.0}{0.2}]$ \n", + "& \\scalebox{0.8}{$\\begin{matrix} \\mask{0.00}{0.03} & \\mask{0.00}{0.75} & \\mask{0.00}{0.75} \\\\ \\mask{0.00}{0.04} & \\mask{0.00}{0.48} & \\mask{0.00}{0.48} \\end{matrix}$} \n", + "& \\scalebox{0.8}{$\\begin{matrix} \\mask{0.00}{0.05} & \\mask{0.00}{4.28} & \\mask{0.00}{4.28} \\\\ \\mask{0.00}{0.01} & \\mask{0.00}{0.49} & \\mask{0.00}{0.49} \\end{matrix}$} \n", + "& \\scalebox{0.8}{$\\begin{matrix} \\mask{0.00}{0.04} & \\mask{0.00}{1.92} & \\mask{0.00}{1.92} \\\\ \\mask{0.00}{0.03} & \\mask{0.00}{0.34} & 
\\mask{0.00}{0.35} \\end{matrix}$} \n", + "& \\scalebox{0.8}{$\\begin{matrix} \\mask{0.00}{0.05} & \\mask{0.00}{3.8} & \\mask{0.00}{3.81} \\\\ \\mask{0.00}{0.02} & \\mask{0.00}{0.44} & \\mask{0.00}{0.44} \\end{matrix}$} \n", + "& \\scalebox{0.8}{$\\begin{matrix} \\mask{0.00}{0.03} & \\mask{0.00}{0.65} & \\mask{0.00}{0.65} \\\\ \\mask{0.00}{0.03} & \\mask{0.00}{0.4} & \\mask{0.00}{0.4} \\end{matrix}$} \n", + "\\\\[12pt]\n" + ] + } + ], + "source": [ + "for (i, π_b) in enumerate(πs):\n", + " print(\"\"\"$[\\mask{{0.0}}{{{}}},\\mask{{0.0}}{{{}}}]$ \"\"\".format(π_b[0,0], π_b[0,1]))\n", + " for (j, π_e) in enumerate(πs):\n", + " ix = i*len(πs)+j\n", + " values = df_out_all_0[ix].iloc[[2,4], [2,3,4]].round(2).values\n", + " print(\"\"\"& \\\\scalebox{0.8}{$\\\\begin{matrix} \"\"\"\n", + " + \"\"\"\\mask{{0.00}}{{{}}} & \\mask{{0.00}}{{{}}} & \\mask{{0.00}}{{{}}} \\\\\\\\ \\mask{{0.00}}{{{}}} & \\mask{{0.00}}{{{}}} & \\mask{{0.00}}{{{}}}\"\"\".format(\n", + " *[int(x) if x.is_integer() else x for x in list(values.ravel())]).replace('-', '\\\\shortminus ')\n", + " + \"\"\" \\end{matrix}$} \"\"\")\n", + " print(\"\"\"\\\\\\\\[12pt]\"\"\")" + ] + }, + { + "cell_type": "markdown", + "id": "4a49a117-038a-4596-bbbb-ea9a97db3f15", + "metadata": {}, + "source": [ + "## Rs[1] setting" + ] + }, + { + "cell_type": "code", + "execution_count": 15, + "id": "b73eb2b6-0ed8-4b3d-b455-b7f8387600ba", + "metadata": {}, + "outputs": [], + "source": [ + "R, sigma = Rs[1], sigmas[1]" + ] + }, + { + "cell_type": "code", + "execution_count": 16, + "id": "bd981a8b-33bb-4097-8f50-0d1ebfb87239", + "metadata": {}, + "outputs": [ + { + "name": "stderr", + "output_type": "stream", + "text": [ + ":36: RuntimeWarning: invalid value encountered in double_scalars\n", + " WISs.append(np.sum(rho * r) / np.sum(rho))\n" + ] + } + ], + "source": [ + "df_out_all_1 = []\n", + "for π_b in πs:\n", + " for π_e in πs:\n", + " df_out = single_run()\n", + " df_out_all_1.append(df_out)" + ] + }, + { + "cell_type": 
"code", + "execution_count": 17, + "id": "fb6d567c-1b1e-452a-b25f-026589f52df7", + "metadata": {}, + "outputs": [ + { + "data": { + "text/html": [ + "\n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + "
π_e [1. 0.] [0. 1.] [0.5 0.5] [0.1 0.9] [0.8 0.2]
π_b
[1. 0.][0.03 0.5 0.5 ]\n", + " [0.03 0.5 0.5 ][-1. 0. 1. ]\n", + " [ 0.02 0.47 0.47][-0.48 0.25 0.54]\n", + " [ 0.03 0.34 0.35][-0.9 0.05 0.9 ]\n", + " [ 0.02 0.43 0.43][-0.17 0.4 0.43]\n", + " [ 0.03 0.41 0.41]
[0. 1.][1. 0. 1. ]\n", + " [0.02 0.47 0.47][0.03 0.5 0.5 ]\n", + " [0.03 0.5 0.5 ][0.52 0.25 0.57]\n", + " [0.03 0.34 0.35][0.13 0.45 0.47]\n", + " [0.03 0.45 0.45][0.81 0.1 0.81]\n", + " [0.02 0.39 0.39]
[0.5 0.5][0.06 1.17 1.17]\n", + " [0.03 0.47 0.47][0.06 1.28 1.28]\n", + " [0.02 0.5 0.5 ][0.06 1.12 1.12]\n", + " [0.03 0.34 0.35][0.06 1.23 1.23]\n", + " [0.02 0.46 0.46][0.06 1.13 1.13]\n", + " [0.03 0.39 0.39]
[0.1 0.9][-0.04 3.37 3.37]\n", + " [ 0.01 0.47 0.47][0.03 0.64 0.64]\n", + " [0.04 0.5 0.5 ][-0.01 1.86 1.86]\n", + " [ 0.03 0.34 0.35][0.02 0.8 0.8 ]\n", + " [0.03 0.45 0.45][-0.03 2.76 2.76]\n", + " [ 0.02 0.39 0.39]
[0.8 0.2][0.02 0.74 0.74]\n", + " [0.04 0.48 0.48][0.06 2.42 2.42]\n", + " [0.01 0.49 0.49][0.04 1.46 1.46]\n", + " [0.03 0.34 0.35][0.06 2.22 2.23]\n", + " [0.02 0.44 0.44][0.03 0.95 0.96]\n", + " [0.03 0.4 0.4 ]
" + ], + "text/plain": [ + "" + ] + }, + "metadata": {}, + "output_type": "display_data" + } + ], + "source": [ + "df_tmp = pd.DataFrame(index=[str(π)[1:-1] for π in πs], columns=[str(π)[1:-1] for π in πs])\n", + "df_tmp.index.name = 'π_b'\n", + "df_tmp.columns.name = 'π_e'\n", + "for (i, π_b), (j, π_e) in itertools.product(enumerate(πs), enumerate(πs)):\n", + " ix = i*len(πs)+j\n", + " df_tmp.iloc[i,j] = str(df_out_all_1[ix].iloc[[2,4], [2,3,4]].round(2).values)[1:-1]\n", + "display(df_tmp.style.set_properties(**{'white-space': 'pre-wrap'}))" + ] + }, + { + "cell_type": "code", + "execution_count": 18, + "id": "61ff3a2a-aeee-4e6f-909d-cfa6431e48bd", + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "$[\\mask{0.0}{1.0},\\mask{0.0}{0.0}]$ \n", + "& \\scalebox{0.8}{$\\begin{matrix} \\mask{0.00}{0.03} & \\mask{0.00}{0.5} & \\mask{0.00}{0.5} \\\\ \\mask{0.00}{0.03} & \\mask{0.00}{0.5} & \\mask{0.00}{0.5} \\end{matrix}$} \n", + "& \\scalebox{0.8}{$\\begin{matrix} \\mask{0.00}{\\shortminus 1} & \\mask{0.00}{0} & \\mask{0.00}{1} \\\\ \\mask{0.00}{0.02} & \\mask{0.00}{0.47} & \\mask{0.00}{0.47} \\end{matrix}$} \n", + "& \\scalebox{0.8}{$\\begin{matrix} \\mask{0.00}{\\shortminus 0.48} & \\mask{0.00}{0.25} & \\mask{0.00}{0.54} \\\\ \\mask{0.00}{0.03} & \\mask{0.00}{0.34} & \\mask{0.00}{0.35} \\end{matrix}$} \n", + "& \\scalebox{0.8}{$\\begin{matrix} \\mask{0.00}{\\shortminus 0.9} & \\mask{0.00}{0.05} & \\mask{0.00}{0.9} \\\\ \\mask{0.00}{0.02} & \\mask{0.00}{0.43} & \\mask{0.00}{0.43} \\end{matrix}$} \n", + "& \\scalebox{0.8}{$\\begin{matrix} \\mask{0.00}{\\shortminus 0.17} & \\mask{0.00}{0.4} & \\mask{0.00}{0.43} \\\\ \\mask{0.00}{0.03} & \\mask{0.00}{0.41} & \\mask{0.00}{0.41} \\end{matrix}$} \n", + "\\\\[12pt]\n", + "$[\\mask{0.0}{0.0},\\mask{0.0}{1.0}]$ \n", + "& \\scalebox{0.8}{$\\begin{matrix} \\mask{0.00}{1} & \\mask{0.00}{0} & \\mask{0.00}{1} \\\\ \\mask{0.00}{0.02} & \\mask{0.00}{0.47} & \\mask{0.00}{0.47} 
\\end{matrix}$} \n", + "& \\scalebox{0.8}{$\\begin{matrix} \\mask{0.00}{0.03} & \\mask{0.00}{0.5} & \\mask{0.00}{0.5} \\\\ \\mask{0.00}{0.03} & \\mask{0.00}{0.5} & \\mask{0.00}{0.5} \\end{matrix}$} \n", + "& \\scalebox{0.8}{$\\begin{matrix} \\mask{0.00}{0.52} & \\mask{0.00}{0.25} & \\mask{0.00}{0.57} \\\\ \\mask{0.00}{0.03} & \\mask{0.00}{0.34} & \\mask{0.00}{0.35} \\end{matrix}$} \n", + "& \\scalebox{0.8}{$\\begin{matrix} \\mask{0.00}{0.13} & \\mask{0.00}{0.45} & \\mask{0.00}{0.47} \\\\ \\mask{0.00}{0.03} & \\mask{0.00}{0.45} & \\mask{0.00}{0.45} \\end{matrix}$} \n", + "& \\scalebox{0.8}{$\\begin{matrix} \\mask{0.00}{0.81} & \\mask{0.00}{0.1} & \\mask{0.00}{0.81} \\\\ \\mask{0.00}{0.02} & \\mask{0.00}{0.39} & \\mask{0.00}{0.39} \\end{matrix}$} \n", + "\\\\[12pt]\n", + "$[\\mask{0.0}{0.5},\\mask{0.0}{0.5}]$ \n", + "& \\scalebox{0.8}{$\\begin{matrix} \\mask{0.00}{0.06} & \\mask{0.00}{1.17} & \\mask{0.00}{1.17} \\\\ \\mask{0.00}{0.03} & \\mask{0.00}{0.47} & \\mask{0.00}{0.47} \\end{matrix}$} \n", + "& \\scalebox{0.8}{$\\begin{matrix} \\mask{0.00}{0.06} & \\mask{0.00}{1.28} & \\mask{0.00}{1.28} \\\\ \\mask{0.00}{0.02} & \\mask{0.00}{0.5} & \\mask{0.00}{0.5} \\end{matrix}$} \n", + "& \\scalebox{0.8}{$\\begin{matrix} \\mask{0.00}{0.06} & \\mask{0.00}{1.12} & \\mask{0.00}{1.12} \\\\ \\mask{0.00}{0.03} & \\mask{0.00}{0.34} & \\mask{0.00}{0.35} \\end{matrix}$} \n", + "& \\scalebox{0.8}{$\\begin{matrix} \\mask{0.00}{0.06} & \\mask{0.00}{1.23} & \\mask{0.00}{1.23} \\\\ \\mask{0.00}{0.02} & \\mask{0.00}{0.46} & \\mask{0.00}{0.46} \\end{matrix}$} \n", + "& \\scalebox{0.8}{$\\begin{matrix} \\mask{0.00}{0.06} & \\mask{0.00}{1.13} & \\mask{0.00}{1.13} \\\\ \\mask{0.00}{0.03} & \\mask{0.00}{0.39} & \\mask{0.00}{0.39} \\end{matrix}$} \n", + "\\\\[12pt]\n", + "$[\\mask{0.0}{0.1},\\mask{0.0}{0.9}]$ \n", + "& \\scalebox{0.8}{$\\begin{matrix} \\mask{0.00}{\\shortminus 0.04} & \\mask{0.00}{3.37} & \\mask{0.00}{3.37} \\\\ \\mask{0.00}{0.01} & \\mask{0.00}{0.47} & \\mask{0.00}{0.47} 
\\end{matrix}$} \n", + "& \\scalebox{0.8}{$\\begin{matrix} \\mask{0.00}{0.03} & \\mask{0.00}{0.64} & \\mask{0.00}{0.64} \\\\ \\mask{0.00}{0.04} & \\mask{0.00}{0.5} & \\mask{0.00}{0.5} \\end{matrix}$} \n", + "& \\scalebox{0.8}{$\\begin{matrix} \\mask{0.00}{\\shortminus 0.01} & \\mask{0.00}{1.86} & \\mask{0.00}{1.86} \\\\ \\mask{0.00}{0.03} & \\mask{0.00}{0.34} & \\mask{0.00}{0.35} \\end{matrix}$} \n", + "& \\scalebox{0.8}{$\\begin{matrix} \\mask{0.00}{0.02} & \\mask{0.00}{0.8} & \\mask{0.00}{0.8} \\\\ \\mask{0.00}{0.03} & \\mask{0.00}{0.45} & \\mask{0.00}{0.45} \\end{matrix}$} \n", + "& \\scalebox{0.8}{$\\begin{matrix} \\mask{0.00}{\\shortminus 0.03} & \\mask{0.00}{2.76} & \\mask{0.00}{2.76} \\\\ \\mask{0.00}{0.02} & \\mask{0.00}{0.39} & \\mask{0.00}{0.39} \\end{matrix}$} \n", + "\\\\[12pt]\n", + "$[\\mask{0.0}{0.8},\\mask{0.0}{0.2}]$ \n", + "& \\scalebox{0.8}{$\\begin{matrix} \\mask{0.00}{0.02} & \\mask{0.00}{0.74} & \\mask{0.00}{0.74} \\\\ \\mask{0.00}{0.04} & \\mask{0.00}{0.48} & \\mask{0.00}{0.48} \\end{matrix}$} \n", + "& \\scalebox{0.8}{$\\begin{matrix} \\mask{0.00}{0.06} & \\mask{0.00}{2.42} & \\mask{0.00}{2.42} \\\\ \\mask{0.00}{0.01} & \\mask{0.00}{0.49} & \\mask{0.00}{0.49} \\end{matrix}$} \n", + "& \\scalebox{0.8}{$\\begin{matrix} \\mask{0.00}{0.04} & \\mask{0.00}{1.46} & \\mask{0.00}{1.46} \\\\ \\mask{0.00}{0.03} & \\mask{0.00}{0.34} & \\mask{0.00}{0.35} \\end{matrix}$} \n", + "& \\scalebox{0.8}{$\\begin{matrix} \\mask{0.00}{0.06} & \\mask{0.00}{2.22} & \\mask{0.00}{2.23} \\\\ \\mask{0.00}{0.02} & \\mask{0.00}{0.44} & \\mask{0.00}{0.44} \\end{matrix}$} \n", + "& \\scalebox{0.8}{$\\begin{matrix} \\mask{0.00}{0.03} & \\mask{0.00}{0.95} & \\mask{0.00}{0.96} \\\\ \\mask{0.00}{0.03} & \\mask{0.00}{0.4} & \\mask{0.00}{0.4} \\end{matrix}$} \n", + "\\\\[12pt]\n" + ] + } + ], + "source": [ + "for (i, π_b) in enumerate(πs):\n", + " print(\"\"\"$[\\mask{{0.0}}{{{}}},\\mask{{0.0}}{{{}}}]$ \"\"\".format(π_b[0,0], π_b[0,1]))\n", + " for (j, π_e) in enumerate(πs):\n", 
+ " ix = i*len(πs)+j\n", + " values = df_out_all_1[ix].iloc[[2,4], [2,3,4]].round(2).values\n", + " print(\"\"\"& \\\\scalebox{0.8}{$\\\\begin{matrix} \"\"\"\n", + " + \"\"\"\\mask{{0.00}}{{{}}} & \\mask{{0.00}}{{{}}} & \\mask{{0.00}}{{{}}} \\\\\\\\ \\mask{{0.00}}{{{}}} & \\mask{{0.00}}{{{}}} & \\mask{{0.00}}{{{}}}\"\"\".format(\n", + " *[int(x) if x.is_integer() else x for x in list(values.ravel())]).replace('-', '\\\\shortminus ')\n", + " + \"\"\" \\end{matrix}$} \"\"\")\n", + " print(\"\"\"\\\\\\\\[12pt]\"\"\")" + ] + }, + { + "cell_type": "markdown", + "id": "f34399d8-8dcb-4884-888c-aaf437f7f79e", + "metadata": {}, + "source": [ + "## Rs[2] setting" + ] + }, + { + "cell_type": "code", + "execution_count": 19, + "id": "392b0aa2-59da-4f15-a3a8-1f3bba5aaa1d", + "metadata": {}, + "outputs": [], + "source": [ + "R, sigma = Rs[2], sigmas[2]" + ] + }, + { + "cell_type": "code", + "execution_count": 20, + "id": "5936bc23-075f-437f-b779-b65c06322d22", + "metadata": {}, + "outputs": [ + { + "name": "stderr", + "output_type": "stream", + "text": [ + ":36: RuntimeWarning: invalid value encountered in double_scalars\n", + " WISs.append(np.sum(rho * r) / np.sum(rho))\n" + ] + } + ], + "source": [ + "df_out_all_2 = []\n", + "for π_b in πs:\n", + " for π_e in πs:\n", + " df_out = single_run()\n", + " df_out_all_2.append(df_out)" + ] + }, + { + "cell_type": "code", + "execution_count": 21, + "id": "bec20438-0e82-48f9-ba54-2bc17b8f7065", + "metadata": {}, + "outputs": [ + { + "data": { + "text/html": [ + "\n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + "
π_e [1. 0.] [0. 1.] [0.5 0.5] [0.1 0.9] [0.8 0.2]
π_b
[1. 0.][0.03 0.5 0.5 ]\n", + " [0.03 0.5 0.5 ][2. 0. 2. ]\n", + " [0.02 0.47 0.47][1.02 0.25 1.05]\n", + " [0.03 0.34 0.35][1.8 0.05 1.8 ]\n", + " [0.02 0.43 0.43][0.43 0.4 0.58]\n", + " [0.03 0.41 0.41]
[0. 1.][1. 0. 1. ]\n", + " [0.02 0.47 0.47][0.03 0.5 0.5 ]\n", + " [0.03 0.5 0.5 ][0.52 0.25 0.57]\n", + " [0.03 0.34 0.35][0.13 0.45 0.47]\n", + " [0.03 0.45 0.45][0.81 0.1 0.81]\n", + " [0.02 0.39 0.39]
[0.5 0.5][0.06 1.17 1.17]\n", + " [0.03 0.47 0.47][-0.01 2.1 2.1 ]\n", + " [ 0.02 0.5 0.5 ][0.02 0.7 0.71]\n", + " [0.03 0.34 0.35][-0. 1.8 1.8 ]\n", + " [ 0.02 0.46 0.46][0.04 0.67 0.67]\n", + " [0.03 0.39 0.39]
[0.1 0.9][-0.04 3.37 3.37]\n", + " [ 0.01 0.47 0.47][0.05 0.86 0.86]\n", + " [0.04 0.5 0.5 ][0. 1.41 1.41]\n", + " [0.03 0.34 0.35][0.04 0.58 0.58]\n", + " [0.03 0.45 0.45][-0.02 2.57 2.57]\n", + " [ 0.02 0.39 0.39]
[0.8 0.2][0.02 0.74 0.74]\n", + " [0.04 0.48 0.48][0.09 4.02 4.02]\n", + " [0.01 0.49 0.49][0.06 1.8 1.8 ]\n", + " [0.03 0.34 0.35][0.08 3.57 3.57]\n", + " [0.02 0.44 0.44][0.04 0.63 0.63]\n", + " [0.03 0.4 0.4 ]
" + ], + "text/plain": [ + "" + ] + }, + "metadata": {}, + "output_type": "display_data" + } + ], + "source": [ + "df_tmp = pd.DataFrame(index=[str(π)[1:-1] for π in πs], columns=[str(π)[1:-1] for π in πs])\n", + "df_tmp.index.name = 'π_b'\n", + "df_tmp.columns.name = 'π_e'\n", + "for (i, π_b), (j, π_e) in itertools.product(enumerate(πs), enumerate(πs)):\n", + " ix = i*len(πs)+j\n", + " df_tmp.iloc[i,j] = str(df_out_all_2[ix].iloc[[2,4], [2,3,4]].round(2).values)[1:-1]\n", + "display(df_tmp.style.set_properties(**{'white-space': 'pre-wrap'}))" + ] + }, + { + "cell_type": "code", + "execution_count": 22, + "id": "3aa514d1-939e-4adb-94c1-8732c1875c7b", + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "$[\\mask{0.0}{1.0},\\mask{0.0}{0.0}]$ \n", + "& \\scalebox{0.8}{$\\begin{matrix} \\mask{0.00}{0.03} & \\mask{0.00}{0.5} & \\mask{0.00}{0.5} \\\\ \\mask{0.00}{0.03} & \\mask{0.00}{0.5} & \\mask{0.00}{0.5} \\end{matrix}$} \n", + "& \\scalebox{0.8}{$\\begin{matrix} \\mask{0.00}{2} & \\mask{0.00}{0} & \\mask{0.00}{2} \\\\ \\mask{0.00}{0.02} & \\mask{0.00}{0.47} & \\mask{0.00}{0.47} \\end{matrix}$} \n", + "& \\scalebox{0.8}{$\\begin{matrix} \\mask{0.00}{1.02} & \\mask{0.00}{0.25} & \\mask{0.00}{1.05} \\\\ \\mask{0.00}{0.03} & \\mask{0.00}{0.34} & \\mask{0.00}{0.35} \\end{matrix}$} \n", + "& \\scalebox{0.8}{$\\begin{matrix} \\mask{0.00}{1.8} & \\mask{0.00}{0.05} & \\mask{0.00}{1.8} \\\\ \\mask{0.00}{0.02} & \\mask{0.00}{0.43} & \\mask{0.00}{0.43} \\end{matrix}$} \n", + "& \\scalebox{0.8}{$\\begin{matrix} \\mask{0.00}{0.43} & \\mask{0.00}{0.4} & \\mask{0.00}{0.58} \\\\ \\mask{0.00}{0.03} & \\mask{0.00}{0.41} & \\mask{0.00}{0.41} \\end{matrix}$} \n", + "\\\\[12pt]\n", + "$[\\mask{0.0}{0.0},\\mask{0.0}{1.0}]$ \n", + "& \\scalebox{0.8}{$\\begin{matrix} \\mask{0.00}{1} & \\mask{0.00}{0} & \\mask{0.00}{1} \\\\ \\mask{0.00}{0.02} & \\mask{0.00}{0.47} & \\mask{0.00}{0.47} \\end{matrix}$} \n", + "& \\scalebox{0.8}{$\\begin{matrix} 
\\mask{0.00}{0.03} & \\mask{0.00}{0.5} & \\mask{0.00}{0.5} \\\\ \\mask{0.00}{0.03} & \\mask{0.00}{0.5} & \\mask{0.00}{0.5} \\end{matrix}$} \n", + "& \\scalebox{0.8}{$\\begin{matrix} \\mask{0.00}{0.52} & \\mask{0.00}{0.25} & \\mask{0.00}{0.57} \\\\ \\mask{0.00}{0.03} & \\mask{0.00}{0.34} & \\mask{0.00}{0.35} \\end{matrix}$} \n", + "& \\scalebox{0.8}{$\\begin{matrix} \\mask{0.00}{0.13} & \\mask{0.00}{0.45} & \\mask{0.00}{0.47} \\\\ \\mask{0.00}{0.03} & \\mask{0.00}{0.45} & \\mask{0.00}{0.45} \\end{matrix}$} \n", + "& \\scalebox{0.8}{$\\begin{matrix} \\mask{0.00}{0.81} & \\mask{0.00}{0.1} & \\mask{0.00}{0.81} \\\\ \\mask{0.00}{0.02} & \\mask{0.00}{0.39} & \\mask{0.00}{0.39} \\end{matrix}$} \n", + "\\\\[12pt]\n", + "$[\\mask{0.0}{0.5},\\mask{0.0}{0.5}]$ \n", + "& \\scalebox{0.8}{$\\begin{matrix} \\mask{0.00}{0.06} & \\mask{0.00}{1.17} & \\mask{0.00}{1.17} \\\\ \\mask{0.00}{0.03} & \\mask{0.00}{0.47} & \\mask{0.00}{0.47} \\end{matrix}$} \n", + "& \\scalebox{0.8}{$\\begin{matrix} \\mask{0.00}{\\shortminus 0.01} & \\mask{0.00}{2.1} & \\mask{0.00}{2.1} \\\\ \\mask{0.00}{0.02} & \\mask{0.00}{0.5} & \\mask{0.00}{0.5} \\end{matrix}$} \n", + "& \\scalebox{0.8}{$\\begin{matrix} \\mask{0.00}{0.02} & \\mask{0.00}{0.7} & \\mask{0.00}{0.71} \\\\ \\mask{0.00}{0.03} & \\mask{0.00}{0.34} & \\mask{0.00}{0.35} \\end{matrix}$} \n", + "& \\scalebox{0.8}{$\\begin{matrix} \\mask{0.00}{0} & \\mask{0.00}{1.8} & \\mask{0.00}{1.8} \\\\ \\mask{0.00}{0.02} & \\mask{0.00}{0.46} & \\mask{0.00}{0.46} \\end{matrix}$} \n", + "& \\scalebox{0.8}{$\\begin{matrix} \\mask{0.00}{0.04} & \\mask{0.00}{0.67} & \\mask{0.00}{0.67} \\\\ \\mask{0.00}{0.03} & \\mask{0.00}{0.39} & \\mask{0.00}{0.39} \\end{matrix}$} \n", + "\\\\[12pt]\n", + "$[\\mask{0.0}{0.1},\\mask{0.0}{0.9}]$ \n", + "& \\scalebox{0.8}{$\\begin{matrix} \\mask{0.00}{\\shortminus 0.04} & \\mask{0.00}{3.37} & \\mask{0.00}{3.37} \\\\ \\mask{0.00}{0.01} & \\mask{0.00}{0.47} & \\mask{0.00}{0.47} \\end{matrix}$} \n", + "& \\scalebox{0.8}{$\\begin{matrix} 
\\mask{0.00}{0.05} & \\mask{0.00}{0.86} & \\mask{0.00}{0.86} \\\\ \\mask{0.00}{0.04} & \\mask{0.00}{0.5} & \\mask{0.00}{0.5} \\end{matrix}$} \n", + "& \\scalebox{0.8}{$\\begin{matrix} \\mask{0.00}{0} & \\mask{0.00}{1.41} & \\mask{0.00}{1.41} \\\\ \\mask{0.00}{0.03} & \\mask{0.00}{0.34} & \\mask{0.00}{0.35} \\end{matrix}$} \n", + "& \\scalebox{0.8}{$\\begin{matrix} \\mask{0.00}{0.04} & \\mask{0.00}{0.58} & \\mask{0.00}{0.58} \\\\ \\mask{0.00}{0.03} & \\mask{0.00}{0.45} & \\mask{0.00}{0.45} \\end{matrix}$} \n", + "& \\scalebox{0.8}{$\\begin{matrix} \\mask{0.00}{\\shortminus 0.02} & \\mask{0.00}{2.57} & \\mask{0.00}{2.57} \\\\ \\mask{0.00}{0.02} & \\mask{0.00}{0.39} & \\mask{0.00}{0.39} \\end{matrix}$} \n", + "\\\\[12pt]\n", + "$[\\mask{0.0}{0.8},\\mask{0.0}{0.2}]$ \n", + "& \\scalebox{0.8}{$\\begin{matrix} \\mask{0.00}{0.02} & \\mask{0.00}{0.74} & \\mask{0.00}{0.74} \\\\ \\mask{0.00}{0.04} & \\mask{0.00}{0.48} & \\mask{0.00}{0.48} \\end{matrix}$} \n", + "& \\scalebox{0.8}{$\\begin{matrix} \\mask{0.00}{0.09} & \\mask{0.00}{4.02} & \\mask{0.00}{4.02} \\\\ \\mask{0.00}{0.01} & \\mask{0.00}{0.49} & \\mask{0.00}{0.49} \\end{matrix}$} \n", + "& \\scalebox{0.8}{$\\begin{matrix} \\mask{0.00}{0.06} & \\mask{0.00}{1.8} & \\mask{0.00}{1.8} \\\\ \\mask{0.00}{0.03} & \\mask{0.00}{0.34} & \\mask{0.00}{0.35} \\end{matrix}$} \n", + "& \\scalebox{0.8}{$\\begin{matrix} \\mask{0.00}{0.08} & \\mask{0.00}{3.57} & \\mask{0.00}{3.57} \\\\ \\mask{0.00}{0.02} & \\mask{0.00}{0.44} & \\mask{0.00}{0.44} \\end{matrix}$} \n", + "& \\scalebox{0.8}{$\\begin{matrix} \\mask{0.00}{0.04} & \\mask{0.00}{0.63} & \\mask{0.00}{0.63} \\\\ \\mask{0.00}{0.03} & \\mask{0.00}{0.4} & \\mask{0.00}{0.4} \\end{matrix}$} \n", + "\\\\[12pt]\n" + ] + } + ], + "source": [ + "for (i, π_b) in enumerate(πs):\n", + " print(\"\"\"$[\\mask{{0.0}}{{{}}},\\mask{{0.0}}{{{}}}]$ \"\"\".format(π_b[0,0], π_b[0,1]))\n", + " for (j, π_e) in enumerate(πs):\n", + " ix = i*len(πs)+j\n", + " values = df_out_all_2[ix].iloc[[2,4], 
[2,3,4]].round(2).values\n", + " print(\"\"\"& \\\\scalebox{0.8}{$\\\\begin{matrix} \"\"\"\n", + " + \"\"\"\\mask{{0.00}}{{{}}} & \\mask{{0.00}}{{{}}} & \\mask{{0.00}}{{{}}} \\\\\\\\ \\mask{{0.00}}{{{}}} & \\mask{{0.00}}{{{}}} & \\mask{{0.00}}{{{}}}\"\"\".format(\n", + " *[int(x) if x.is_integer() else x for x in list(values.ravel())]).replace('-', '\\\\shortminus ')\n", + " + \"\"\" \\end{matrix}$} \"\"\")\n", + " print(\"\"\"\\\\\\\\[12pt]\"\"\")" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "e0fa9eb3-bec0-4429-96ae-075c203d883f", + "metadata": {}, + "outputs": [], + "source": [] + } + ], + "metadata": { + "kernelspec": { + "display_name": "RL_venv", + "language": "python", + "name": "rl_venv" + }, + "language_info": { + "codemirror_mode": { + "name": "ipython", + "version": 3 + }, + "file_extension": ".py", + "mimetype": "text/x-python", + "name": "python", + "nbconvert_exporter": "python", + "pygments_lexer": "ipython3", + "version": "3.8.8" + } + }, + "nbformat": 4, + "nbformat_minor": 5 +} diff --git a/synthetic/bandit_compare-2state.ipynb b/synthetic/bandit_compare-2state.ipynb new file mode 100644 index 0000000..811d362 --- /dev/null +++ b/synthetic/bandit_compare-2state.ipynb @@ -0,0 +1,1077 @@ +{ + "cells": [ + { + "cell_type": "code", + "execution_count": 1, + "id": "a6eac7c9-b2bb-4995-bf22-a9d66093f9d6", + "metadata": {}, + "outputs": [], + "source": [ + "import numpy as np\n", + "import pandas as pd\n", + "from IPython.display import display\n", + "import matplotlib.pyplot as plt\n", + "%config InlineBackend.figure_formats = ['svg']\n", + "import matplotlib\n", + "matplotlib.rcParams['text.usetex'] = True\n", + "matplotlib.rcParams['font.sans-serif'] = ['FreeSans']\n", + "import seaborn as sns\n", + "import itertools\n", + "from tqdm import tqdm\n", + "import joblib" + ] + }, + { + "cell_type": "code", + "execution_count": 2, + "id": "16192b2c-09d3-49b7-98d1-edc8c5bbbd22", + "metadata": {}, + "outputs": [], + "source": [ + "d0 
= np.array([0.5, 0.5])" + ] + }, + { + "cell_type": "code", + "execution_count": 3, + "id": "3b17c719-1120-434f-9785-e5d3583909e4", + "metadata": {}, + "outputs": [], + "source": [ + "Rs = [\n", + " np.array([[1., 2.], [0., 0.]]),\n", + " np.array([[-1., 1.], [0., 0.]]),\n", + " np.array([[-1., -2.], [0., 0.]]),\n", + "]\n", + "sigmas = [\n", + " np.array([[0.5, 0.5], [0.5, 0.5]]),\n", + " np.array([[0.5, 0.5], [0.5, 0.5]]),\n", + " np.array([[0.5, 0.5], [0.5, 0.5]]),\n", + "]" + ] + }, + { + "cell_type": "code", + "execution_count": 4, + "id": "a2617940-a7a0-4de0-b96c-0c4d87b4365c", + "metadata": {}, + "outputs": [], + "source": [ + "πs = [\n", + " np.array([[1., 0.], [1., 0.]]),\n", + " np.array([[0., 1.], [1., 0.]]),\n", + " np.array([[0.5, 0.5], [1., 0.]]),\n", + " np.array([[0.1, 0.9], [1., 0.]]),\n", + " np.array([[0.8, 0.2], [1., 0.]]),\n", + "]" + ] + }, + { + "cell_type": "code", + "execution_count": 5, + "id": "3b57307b-c5c4-4b26-be34-5e8d13a2ea20", + "metadata": {}, + "outputs": [], + "source": [ + "use_πD = False" + ] + }, + { + "cell_type": "code", + "execution_count": 6, + "id": "815dabf9-cdff-4d7a-bac4-889d6b046442", + "metadata": {}, + "outputs": [], + "source": [ + "N = 1\n", + "runs = 1000" + ] + }, + { + "cell_type": "code", + "execution_count": 7, + "id": "a82feb16-b247-409b-b556-21e1bab36602", + "metadata": {}, + "outputs": [], + "source": [ + "def single_exp_setting(π_b, π_e):\n", + " np.random.seed(42)\n", + "\n", + " # True value of π_e\n", + " Js = []\n", + " for seed in range(runs):\n", + " rng = np.random.default_rng(seed=10+seed)\n", + " x = rng.choice(len(d0), size=N, p=d0)\n", + " a = np.array([rng.choice(2, p=π_e[xi]) for xi in x])\n", + " r = np.array([rng.normal(R[xi,ai], sigma[xi,ai]) for xi,ai in zip(x,a)])\n", + " J = np.sum(r) / N\n", + " Js.append(J)\n", + "\n", + " # Standard IS\n", + " Gs = []\n", + " OISs = []\n", + " WISs = []\n", + " for seed in range(runs):\n", + " rng = np.random.default_rng(seed=10+seed)\n", + " x = 
rng.choice(len(d0), size=N, p=d0)\n", + " a = np.array([rng.choice(2, p=π_b[xi]) for xi in x])\n", + " r = np.array([rng.normal(R[xi,ai], sigma[xi,ai]) for xi,ai in zip(x,a)])\n", + " G = np.sum(r) / N\n", + " Gs.append(G)\n", + "\n", + " if use_πD:\n", + " assert False\n", + " else:\n", + " π_b_ = π_b\n", + "\n", + " rho = π_e[x,a] / π_b_[x,a]\n", + " OISs.append(np.sum(rho * r) / N)\n", + " WISs.append(np.sum(rho * r) / np.sum(rho))\n", + "\n", + " \n", + " # Collect data using π_b - naive approach\n", + " Naive_OISs = []\n", + " for seed in range(runs):\n", + " rng = np.random.default_rng(seed=10+seed)\n", + " rng_c = np.random.default_rng(seed=100000+seed)\n", + " x = rng.choice(len(d0), size=N, p=d0)\n", + " a = np.array([rng.choice(2, p=π_b[xi]) for xi in x])\n", + " r = np.array([rng.normal(R[xi,ai], sigma[xi,ai]) for xi,ai in zip(x,a)])\n", + " rho = π_e[x,a] / π_b[x,a]\n", + "\n", + " # counterfactual flag\n", + " c = np.array([rng_c.choice(2, p=[1-Pc[xi,ai], Pc[xi,ai]]) for xi,ai in zip(x,a)])\n", + "\n", + " # counterfactual reward\n", + " rc = np.array([rng_c.normal(R[xi,1-ai], sigma[xi,1-ai]) for xi,ai in zip(x,a)])\n", + " rc[c==0] = np.nan\n", + "\n", + " # trajectory-wise weight\n", + " w = np.ones(N)\n", + " w[c==1] = ww_naive[x[c==1], a[c==1], a[c==1]]\n", + " wc = np.zeros(N)\n", + " wc[c==1] = ww_naive[x[c==1], a[c==1], 1-a[c==1]]\n", + "\n", + " if use_πD:\n", + " # augmented behavior policy\n", + " assert False\n", + " else:\n", + " # augmented behavior policy\n", + " π_b_ = np.array([\n", + " [π_b[0,0]*ww_naive[0,0,0]+π_b[0,1]*ww_naive[0,1,0], π_b[0,0]*ww_naive[0,0,1]+π_b[0,1]*ww_naive[0,1,1]],\n", + " [π_b[1,0]*ww_naive[1,0,0]+π_b[1,1]*ww_naive[1,1,0], π_b[1,0]*ww_naive[1,0,1]+π_b[1,1]*ww_naive[1,1,1]],\n", + " ])\n", + " π_b_ = π_b_ / π_b_.sum(axis=1, keepdims=True)\n", + "\n", + " # Naive_WISs.append(\n", + " # (np.sum(w* π_e[x,a] / π_b_[x,a] * r) + np.nansum(wc* π_e[x,1-a] / π_b_[x,1-a] * rc)) / (np.sum(w* π_e[x,a] / π_b_[x,a]) + 
np.sum((wc* π_e[x,1-a] / π_b_[x,1-a])[c==1]))\n", + " # )\n", + " ## Add factual and counterfactual separately\n", + " Naive_OISs.append(np.sum(π_e[x,a] / π_b_[x,a] * r) / N)\n", + " if np.sum(c) > 0:\n", + " Naive_OISs.append(np.sum(π_e[x,1-a] / π_b_[x,1-a] * rc) / np.sum(c))\n", + "\n", + "\n", + " # Collect data using π_b - combining counterfactuals with factuals\n", + " FC_OISs_w = []\n", + " FC_WISs_w = []\n", + " for seed in range(runs):\n", + " rng = np.random.default_rng(seed=10+seed)\n", + " rng_c = np.random.default_rng(seed=100000+seed)\n", + " x = rng.choice(len(d0), size=N, p=d0)\n", + " a = np.array([rng.choice(2, p=π_b[xi]) for xi in x])\n", + " r = np.array([rng.normal(R[xi,ai], sigma[xi,ai]) for xi,ai in zip(x,a)])\n", + " rho = π_e[x,a] / π_b[x,a]\n", + "\n", + " # counterfactual flag\n", + " c = np.array([rng_c.choice(2, p=[1-Pc[xi,ai], Pc[xi,ai]]) for xi,ai in zip(x,a)])\n", + "\n", + " # counterfactual reward\n", + " rc = np.array([rng_c.normal(R[xi,1-ai], sigma[xi,1-ai]) for xi,ai in zip(x,a)])\n", + " rc[c==0] = np.nan\n", + "\n", + " # trajectory-wise weight\n", + " w = np.ones(N)\n", + " w[c==1] = ww[x[c==1], a[c==1], a[c==1]]\n", + " wc = np.zeros(N)\n", + " wc[c==1] = ww[x[c==1], a[c==1], 1-a[c==1]]\n", + "\n", + " if use_πD:\n", + " # augmented behavior policy\n", + " assert False\n", + " else:\n", + " # augmented behavior policy\n", + " π_b_ = np.array([\n", + " [π_b[0,0]*ww[0,0,0]+π_b[0,1]*ww[0,1,0], π_b[0,0]*ww[0,0,1]+π_b[0,1]*ww[0,1,1]],\n", + " [π_b[1,0]*ww[1,0,0]+π_b[1,1]*ww[1,1,0], π_b[1,0]*ww[1,0,1]+π_b[1,1]*ww[1,1,1]],\n", + " ])\n", + " π_b_ = π_b_ / π_b_.sum(axis=1, keepdims=True)\n", + "\n", + " FC_OISs_w.append(\n", + " (np.sum(w* π_e[x,a] / π_b_[x,a] * r) + np.nansum(wc* π_e[x,1-a] / π_b_[x,1-a] * rc)) / (N)\n", + " )\n", + " FC_WISs_w.append(\n", + " (np.sum(w* π_e[x,a] / π_b_[x,a] * r) + np.nansum(wc* π_e[x,1-a] / π_b_[x,1-a] * rc)) / (np.sum(w* π_e[x,a] / π_b_[x,a]) + np.sum((wc* π_e[x,1-a] / π_b_[x,1-a])[c==1])),\n", + 
" )\n", + "\n", + " df_bias_var = []\n", + " for name, values in [\n", + " ('$\\hat{v}(\\pi_e)$', Js),\n", + " ('$\\hat{v}(\\pi_b)$', Gs),\n", + " ('OIS', OISs),\n", + " ('WIS', WISs),\n", + " ('C-OIS', FC_OISs_w),\n", + " ('C-WIS', FC_WISs_w),\n", + " ('Naive-OIS', Naive_OISs),\n", + " ]:\n", + " df_bias_var.append([name, \n", + " np.mean(values), \n", + " np.mean(values - d0@np.sum(π_e*R, axis=1)), \n", + " np.sqrt(np.var(values)), \n", + " np.sqrt(np.mean(np.square(values - d0@np.sum(π_e*R, axis=1))))])\n", + " return pd.DataFrame(df_bias_var, columns=['Approach', 'Mean', 'Bias', 'Std', 'RMSE'])" + ] + }, + { + "cell_type": "markdown", + "id": "3fdea8d1-8479-4592-b2ae-4a65c032daf5", + "metadata": {}, + "source": [ + "# Ideal counterfactual annotations, equal weights" + ] + }, + { + "cell_type": "code", + "execution_count": 8, + "id": "0ee15c3b-0180-447d-b9ac-29d91b1e8c2a", + "metadata": {}, + "outputs": [], + "source": [ + "# Counterfactual-augmented IS\n", + "## probability of getting a counterfactual annotation\n", + "Pc = np.array([\n", + " [1., 1.],\n", + " [0., 0.],\n", + "])\n", + "## Weights assigned to factual and counterfactuals\n", + "ww = np.array([\n", + " [[0.5, 0.5], [0.5, 0.5]],\n", + " [[1, 0], [0, 1]],\n", + "])\n", + "\n", + "ww_naive = np.array([\n", + " [[1, 1], [1, 1]],\n", + " [[1, 0], [0, 1]],\n", + "])" + ] + }, + { + "cell_type": "markdown", + "id": "72d31031-7e17-4dfc-b762-7d8d36429480", + "metadata": {}, + "source": [ + "## Rs[0] setting" + ] + }, + { + "cell_type": "code", + "execution_count": 9, + "id": "b03a8eda-aadd-4387-89e6-09e4d58d9e8f", + "metadata": {}, + "outputs": [], + "source": [ + "R, sigma = Rs[0], sigmas[0]" + ] + }, + { + "cell_type": "code", + "execution_count": 10, + "id": "e1d34e0a-5206-4e5d-ae9c-f56886346f72", + "metadata": {}, + "outputs": [ + { + "name": "stderr", + "output_type": "stream", + "text": [ + " 0%| | 0/25 [00:00:115: RuntimeWarning: invalid value encountered in true_divide\n", + " (np.sum(w* π_e[x,a] 
/ π_b_[x,a] * r) + np.nansum(wc* π_e[x,1-a] / π_b_[x,1-a] * rc)) / (N)\n", + ":118: RuntimeWarning: invalid value encountered in true_divide\n", + " (np.sum(w* π_e[x,a] / π_b_[x,a] * r) + np.nansum(wc* π_e[x,1-a] / π_b_[x,1-a] * rc)) / (np.sum(w* π_e[x,a] / π_b_[x,a]) + np.sum((wc* π_e[x,1-a] / π_b_[x,1-a])[c==1])),\n", + " 4%|▍ | 1/25 [00:07<02:51, 7.13s/it]:33: RuntimeWarning: invalid value encountered in double_scalars\n", + " WISs.append(np.sum(rho * r) / np.sum(rho))\n", + "100%|██████████| 25/25 [02:02<00:00, 4.89s/it]\n" + ] + } + ], + "source": [ + "df_out_all_0 = []\n", + "for π_b, π_e in tqdm(list(itertools.product(πs, πs))):\n", + " df_out = single_exp_setting(π_b, π_e)\n", + " df_out_all_0.append(df_out)" + ] + }, + { + "cell_type": "code", + "execution_count": 11, + "id": "04e99347-6451-4aa6-b7df-407f0907bf37", + "metadata": {}, + "outputs": [ + { + "data": { + "text/html": [ + "\n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + "
π_e [1. 0.]\n", + " [1. 0.] [0. 1.]\n", + " [1. 0.] [0.5 0.5]\n", + " [1. 0. ] [0.1 0.9]\n", + " [1. 0. ] [0.8 0.2]\n", + " [1. 0. ]
π_b
[1. 0.]\n", + " [1. 0.][0. 0.7 0.7]\n", + " [0.2 1.1 1.2]\n", + " [0. 0.7 0.7][-1. 0.4 1. ]\n", + " [ 0.3 2. 2. ]\n", + " [ 0. 1.1 1.1][-0.5 0.5 0.7]\n", + " [ 0.3 0.9 1. ]\n", + " [ 0. 0.9 0.9][-0.9 0.4 1. ]\n", + " [ 0.3 1.8 1.8]\n", + " [ 0. 1.1 1.1][-0.2 0.6 0.6]\n", + " [ 0.2 0.9 0.9]\n", + " [ 0. 0.7 0.7]
[0. 1.]\n", + " [1. 0.][-0.5 0.4 0.6]\n", + " [ 0.2 1.1 1.1]\n", + " [ 0. 0.7 0.7][0. 1.1 1.1]\n", + " [0.3 2. 2. ]\n", + " [0. 1.1 1.1][-0.2 0.6 0.7]\n", + " [ 0.3 0.9 1. ]\n", + " [ 0. 0.9 0.9][-0. 1. 1. ]\n", + " [ 0.3 1.8 1.8]\n", + " [ 0. 1.1 1.1][-0.4 0.4 0.6]\n", + " [ 0.2 0.8 0.9]\n", + " [ 0. 0.7 0.7]
[0.5 0.5]\n", + " [1. 0. ][0. 1. 1. ]\n", + " [0.2 1.1 1.1]\n", + " [0. 0.7 0.7][-0. 1.8 1.8]\n", + " [ 0.3 2. 2. ]\n", + " [ 0. 1.1 1.1][0. 1. 1. ]\n", + " [0.3 0.9 1. ]\n", + " [0. 0.9 0.9][-0. 1.6 1.6]\n", + " [ 0.3 1.8 1.8]\n", + " [ 0. 1.1 1.1][0. 0.8 0.8]\n", + " [0.2 0.8 0.9]\n", + " [0. 0.7 0.7]
[0.1 0.9]\n", + " [1. 0. ][0.1 2.6 2.6]\n", + " [0.2 1.1 1.1]\n", + " [0. 0.7 0.7][-0. 1.2 1.2]\n", + " [ 0.3 2. 2. ]\n", + " [ 0. 1.1 1.1][0. 1.3 1.3]\n", + " [0.3 0.9 1. ]\n", + " [0. 0.9 0.9][0. 1.1 1.1]\n", + " [0.3 1.8 1.8]\n", + " [0. 1.1 1.1][0.1 2. 2. ]\n", + " [0.2 0.8 0.9]\n", + " [0. 0.7 0.7]
[0.8 0.2]\n", + " [1. 0. ][0. 0.8 0.8]\n", + " [0.2 1.1 1.2]\n", + " [0. 0.7 0.7][0.1 3.2 3.2]\n", + " [0.3 2. 2. ]\n", + " [0. 1.1 1.1][0. 1.6 1.6]\n", + " [0.3 0.9 1. ]\n", + " [0. 0.9 0.9][0. 2.9 2.9]\n", + " [0.3 1.7 1.8]\n", + " [0. 1.1 1.1][0. 0.8 0.8]\n", + " [0.2 0.8 0.9]\n", + " [0. 0.7 0.7]
" + ], + "text/plain": [ + "" + ] + }, + "metadata": {}, + "output_type": "display_data" + } + ], + "source": [ + "df_tmp = pd.DataFrame(index=[str(π)[1:-1] for π in πs], columns=[str(π)[1:-1] for π in πs])\n", + "df_tmp.index.name = 'π_b'\n", + "df_tmp.columns.name = 'π_e'\n", + "for (i, π_b), (j, π_e) in itertools.product(enumerate(πs), enumerate(πs)):\n", + " ix = i*len(πs)+j\n", + " df_tmp.iloc[i,j] = str(df_out_all_0[ix].iloc[[2,6,4], [2,3,4]].round(1).values)[1:-1]\n", + "display(df_tmp.style.set_properties(**{'white-space': 'pre-wrap'}))" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "46a3ac26-2b7d-4385-a9cb-eb96f3ad1115", + "metadata": { + "tags": [] + }, + "outputs": [], + "source": [ + "for (i, π_b) in enumerate(πs):\n", + " print(\"\"\"$[\\mask{{0.0}}{{{}}},\\mask{{0.0}}{{{}}}]$ \"\"\".format(π_b[0,0], π_b[0,1]))\n", + " for (j, π_e) in enumerate(πs):\n", + " ix = i*len(πs)+j\n", + " values = df_out_all_0[ix].iloc[[2,6,4], [2,3,4]].round(2).values\n", + " print(\"\"\"& \\\\scalebox{0.8}{$\\\\begin{matrix} \"\"\"\n", + " + \"\"\"\\mask{{0.00}}{{{}}} & \\mask{{0.00}}{{{}}} & \\mask{{0.00}}{{{}}} \\\\\\\\ \\mask{{0.00}}{{{}}} & \\mask{{0.00}}{{{}}} & \\mask{{0.00}}{{{}}} \\\\\\\\ \\mask{{0.00}}{{{}}} & \\mask{{0.00}}{{{}}} & \\mask{{0.00}}{{{}}}\"\"\".format(*list(values.ravel())).replace('-', '\\\\shortminus ')\n", + " + \"\"\" \\end{matrix}$} \"\"\")\n", + " print(\"\"\"\\\\\\\\[12pt]\"\"\")" + ] + }, + { + "cell_type": "code", + "execution_count": 13, + "id": "27e8abe2-b622-4f5a-9a42-7690c0ea3372", + "metadata": { + "tags": [] + }, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "$[\\mask{0.0}{1.0},\\mask{0.0}{0.0}]$ \n", + "& \\scalebox{0.8}{$\\begin{matrix} \\mask{0.00}{0} & \\mask{0.00}{0.7} & \\mask{0.00}{0.7} \\\\ \\mask{0.00}{0.2} & \\mask{0.00}{1.1} & \\mask{0.00}{1.2} \\\\ \\mask{0.00}{0} & \\mask{0.00}{0.7} & \\mask{0.00}{0.7} \\end{matrix}$} \n", + "& \\scalebox{0.8}{$\\begin{matrix} 
\\mask{0.00}{\\shortminus 1} & \\mask{0.00}{0.4} & \\mask{0.00}{1} \\\\ \\mask{0.00}{0.3} & \\mask{0.00}{2} & \\mask{0.00}{2} \\\\ \\mask{0.00}{0} & \\mask{0.00}{1.1} & \\mask{0.00}{1.1} \\end{matrix}$} \n", + "& \\scalebox{0.8}{$\\begin{matrix} \\mask{0.00}{\\shortminus 0.5} & \\mask{0.00}{0.5} & \\mask{0.00}{0.7} \\\\ \\mask{0.00}{0.3} & \\mask{0.00}{0.9} & \\mask{0.00}{1} \\\\ \\mask{0.00}{0} & \\mask{0.00}{0.9} & \\mask{0.00}{0.9} \\end{matrix}$} \n", + "& \\scalebox{0.8}{$\\begin{matrix} \\mask{0.00}{\\shortminus 0.9} & \\mask{0.00}{0.4} & \\mask{0.00}{1} \\\\ \\mask{0.00}{0.3} & \\mask{0.00}{1.8} & \\mask{0.00}{1.8} \\\\ \\mask{0.00}{0} & \\mask{0.00}{1.1} & \\mask{0.00}{1.1} \\end{matrix}$} \n", + "& \\scalebox{0.8}{$\\begin{matrix} \\mask{0.00}{\\shortminus 0.2} & \\mask{0.00}{0.6} & \\mask{0.00}{0.6} \\\\ \\mask{0.00}{0.2} & \\mask{0.00}{0.9} & \\mask{0.00}{0.9} \\\\ \\mask{0.00}{0} & \\mask{0.00}{0.7} & \\mask{0.00}{0.7} \\end{matrix}$} \n", + "\\\\[12pt]\n", + "$[\\mask{0.0}{0.0},\\mask{0.0}{1.0}]$ \n", + "& \\scalebox{0.8}{$\\begin{matrix} \\mask{0.00}{\\shortminus 0.5} & \\mask{0.00}{0.4} & \\mask{0.00}{0.6} \\\\ \\mask{0.00}{0.2} & \\mask{0.00}{1.1} & \\mask{0.00}{1.1} \\\\ \\mask{0.00}{0} & \\mask{0.00}{0.7} & \\mask{0.00}{0.7} \\end{matrix}$} \n", + "& \\scalebox{0.8}{$\\begin{matrix} \\mask{0.00}{0} & \\mask{0.00}{1.1} & \\mask{0.00}{1.1} \\\\ \\mask{0.00}{0.3} & \\mask{0.00}{2} & \\mask{0.00}{2} \\\\ \\mask{0.00}{0} & \\mask{0.00}{1.1} & \\mask{0.00}{1.1} \\end{matrix}$} \n", + "& \\scalebox{0.8}{$\\begin{matrix} \\mask{0.00}{\\shortminus 0.2} & \\mask{0.00}{0.6} & \\mask{0.00}{0.7} \\\\ \\mask{0.00}{0.3} & \\mask{0.00}{0.9} & \\mask{0.00}{1} \\\\ \\mask{0.00}{0} & \\mask{0.00}{0.9} & \\mask{0.00}{0.9} \\end{matrix}$} \n", + "& \\scalebox{0.8}{$\\begin{matrix} \\mask{0.00}{0} & \\mask{0.00}{1} & \\mask{0.00}{1} \\\\ \\mask{0.00}{0.3} & \\mask{0.00}{1.8} & \\mask{0.00}{1.8} \\\\ \\mask{0.00}{0} & \\mask{0.00}{1.1} & \\mask{0.00}{1.1} 
\\end{matrix}$} \n", + "& \\scalebox{0.8}{$\\begin{matrix} \\mask{0.00}{\\shortminus 0.4} & \\mask{0.00}{0.4} & \\mask{0.00}{0.6} \\\\ \\mask{0.00}{0.2} & \\mask{0.00}{0.8} & \\mask{0.00}{0.9} \\\\ \\mask{0.00}{0} & \\mask{0.00}{0.7} & \\mask{0.00}{0.7} \\end{matrix}$} \n", + "\\\\[12pt]\n", + "$[\\mask{0.0}{0.5},\\mask{0.0}{0.5}]$ \n", + "& \\scalebox{0.8}{$\\begin{matrix} \\mask{0.00}{0} & \\mask{0.00}{1} & \\mask{0.00}{1} \\\\ \\mask{0.00}{0.2} & \\mask{0.00}{1.1} & \\mask{0.00}{1.1} \\\\ \\mask{0.00}{0} & \\mask{0.00}{0.7} & \\mask{0.00}{0.7} \\end{matrix}$} \n", + "& \\scalebox{0.8}{$\\begin{matrix} \\mask{0.00}{0} & \\mask{0.00}{1.8} & \\mask{0.00}{1.8} \\\\ \\mask{0.00}{0.3} & \\mask{0.00}{2} & \\mask{0.00}{2} \\\\ \\mask{0.00}{0} & \\mask{0.00}{1.1} & \\mask{0.00}{1.1} \\end{matrix}$} \n", + "& \\scalebox{0.8}{$\\begin{matrix} \\mask{0.00}{0} & \\mask{0.00}{1} & \\mask{0.00}{1} \\\\ \\mask{0.00}{0.3} & \\mask{0.00}{0.9} & \\mask{0.00}{1} \\\\ \\mask{0.00}{0} & \\mask{0.00}{0.9} & \\mask{0.00}{0.9} \\end{matrix}$} \n", + "& \\scalebox{0.8}{$\\begin{matrix} \\mask{0.00}{0} & \\mask{0.00}{1.6} & \\mask{0.00}{1.6} \\\\ \\mask{0.00}{0.3} & \\mask{0.00}{1.8} & \\mask{0.00}{1.8} \\\\ \\mask{0.00}{0} & \\mask{0.00}{1.1} & \\mask{0.00}{1.1} \\end{matrix}$} \n", + "& \\scalebox{0.8}{$\\begin{matrix} \\mask{0.00}{0} & \\mask{0.00}{0.8} & \\mask{0.00}{0.8} \\\\ \\mask{0.00}{0.2} & \\mask{0.00}{0.8} & \\mask{0.00}{0.9} \\\\ \\mask{0.00}{0} & \\mask{0.00}{0.7} & \\mask{0.00}{0.7} \\end{matrix}$} \n", + "\\\\[12pt]\n", + "$[\\mask{0.0}{0.1},\\mask{0.0}{0.9}]$ \n", + "& \\scalebox{0.8}{$\\begin{matrix} \\mask{0.00}{0.1} & \\mask{0.00}{2.6} & \\mask{0.00}{2.6} \\\\ \\mask{0.00}{0.2} & \\mask{0.00}{1.1} & \\mask{0.00}{1.1} \\\\ \\mask{0.00}{0} & \\mask{0.00}{0.7} & \\mask{0.00}{0.7} \\end{matrix}$} \n", + "& \\scalebox{0.8}{$\\begin{matrix} \\mask{0.00}{0} & \\mask{0.00}{1.2} & \\mask{0.00}{1.2} \\\\ \\mask{0.00}{0.3} & \\mask{0.00}{2} & \\mask{0.00}{2} \\\\ \\mask{0.00}{0} 
& \\mask{0.00}{1.1} & \\mask{0.00}{1.1} \\end{matrix}$} \n", + "& \\scalebox{0.8}{$\\begin{matrix} \\mask{0.00}{0} & \\mask{0.00}{1.3} & \\mask{0.00}{1.3} \\\\ \\mask{0.00}{0.3} & \\mask{0.00}{0.9} & \\mask{0.00}{1} \\\\ \\mask{0.00}{0} & \\mask{0.00}{0.9} & \\mask{0.00}{0.9} \\end{matrix}$} \n", + "& \\scalebox{0.8}{$\\begin{matrix} \\mask{0.00}{0} & \\mask{0.00}{1.1} & \\mask{0.00}{1.1} \\\\ \\mask{0.00}{0.3} & \\mask{0.00}{1.8} & \\mask{0.00}{1.8} \\\\ \\mask{0.00}{0} & \\mask{0.00}{1.1} & \\mask{0.00}{1.1} \\end{matrix}$} \n", + "& \\scalebox{0.8}{$\\begin{matrix} \\mask{0.00}{0.1} & \\mask{0.00}{2} & \\mask{0.00}{2} \\\\ \\mask{0.00}{0.2} & \\mask{0.00}{0.8} & \\mask{0.00}{0.9} \\\\ \\mask{0.00}{0} & \\mask{0.00}{0.7} & \\mask{0.00}{0.7} \\end{matrix}$} \n", + "\\\\[12pt]\n", + "$[\\mask{0.0}{0.8},\\mask{0.0}{0.2}]$ \n", + "& \\scalebox{0.8}{$\\begin{matrix} \\mask{0.00}{0} & \\mask{0.00}{0.8} & \\mask{0.00}{0.8} \\\\ \\mask{0.00}{0.2} & \\mask{0.00}{1.1} & \\mask{0.00}{1.2} \\\\ \\mask{0.00}{0} & \\mask{0.00}{0.7} & \\mask{0.00}{0.7} \\end{matrix}$} \n", + "& \\scalebox{0.8}{$\\begin{matrix} \\mask{0.00}{0.1} & \\mask{0.00}{3.2} & \\mask{0.00}{3.2} \\\\ \\mask{0.00}{0.3} & \\mask{0.00}{2} & \\mask{0.00}{2} \\\\ \\mask{0.00}{0} & \\mask{0.00}{1.1} & \\mask{0.00}{1.1} \\end{matrix}$} \n", + "& \\scalebox{0.8}{$\\begin{matrix} \\mask{0.00}{0} & \\mask{0.00}{1.6} & \\mask{0.00}{1.6} \\\\ \\mask{0.00}{0.3} & \\mask{0.00}{0.9} & \\mask{0.00}{1} \\\\ \\mask{0.00}{0} & \\mask{0.00}{0.9} & \\mask{0.00}{0.9} \\end{matrix}$} \n", + "& \\scalebox{0.8}{$\\begin{matrix} \\mask{0.00}{0} & \\mask{0.00}{2.9} & \\mask{0.00}{2.9} \\\\ \\mask{0.00}{0.3} & \\mask{0.00}{1.7} & \\mask{0.00}{1.8} \\\\ \\mask{0.00}{0} & \\mask{0.00}{1.1} & \\mask{0.00}{1.1} \\end{matrix}$} \n", + "& \\scalebox{0.8}{$\\begin{matrix} \\mask{0.00}{0} & \\mask{0.00}{0.8} & \\mask{0.00}{0.8} \\\\ \\mask{0.00}{0.2} & \\mask{0.00}{0.8} & \\mask{0.00}{0.9} \\\\ \\mask{0.00}{0} & \\mask{0.00}{0.7} & 
\\mask{0.00}{0.7} \\end{matrix}$} \n", + "\\\\[12pt]\n" + ] + } + ], + "source": [ + "for (i, π_b) in enumerate(πs):\n", + " print(\"\"\"$[\\mask{{0.0}}{{{}}},\\mask{{0.0}}{{{}}}]$ \"\"\".format(π_b[0,0], π_b[0,1]))\n", + " for (j, π_e) in enumerate(πs):\n", + " ix = i*len(πs)+j\n", + " values = df_out_all_0[ix].iloc[[2,6,4], [2,3,4]].round(1).values\n", + " print(\"\"\"& \\\\scalebox{0.8}{$\\\\begin{matrix} \"\"\"\n", + " + \"\"\"\\mask{{0.00}}{{{}}} & \\mask{{0.00}}{{{}}} & \\mask{{0.00}}{{{}}} \\\\\\\\ \\mask{{0.00}}{{{}}} & \\mask{{0.00}}{{{}}} & \\mask{{0.00}}{{{}}} \\\\\\\\ \\mask{{0.00}}{{{}}} & \\mask{{0.00}}{{{}}} & \\mask{{0.00}}{{{}}}\"\"\".format(\n", + " *[int(x) if x.is_integer() else x for x in list(values.ravel())]).replace('-', '\\\\shortminus ')\n", + " + \"\"\" \\end{matrix}$} \"\"\")\n", + " print(\"\"\"\\\\\\\\[12pt]\"\"\")" + ] + }, + { + "cell_type": "markdown", + "id": "c2bcac69-d99d-4ba8-9029-65458f7efe6e", + "metadata": {}, + "source": [ + "## Rs[1] setting" + ] + }, + { + "cell_type": "code", + "execution_count": 14, + "id": "6444e6eb-2fe8-4386-b418-6e984b474356", + "metadata": {}, + "outputs": [], + "source": [ + "R, sigma = Rs[1], sigmas[1]" + ] + }, + { + "cell_type": "code", + "execution_count": 15, + "id": "879f3f69-e16d-42e0-a8c8-d33f99e6e6a4", + "metadata": {}, + "outputs": [ + { + "name": "stderr", + "output_type": "stream", + "text": [ + " 0%| | 0/25 [00:00:115: RuntimeWarning: invalid value encountered in true_divide\n", + " (np.sum(w* π_e[x,a] / π_b_[x,a] * r) + np.nansum(wc* π_e[x,1-a] / π_b_[x,1-a] * rc)) / (N)\n", + ":118: RuntimeWarning: invalid value encountered in true_divide\n", + " (np.sum(w* π_e[x,a] / π_b_[x,a] * r) + np.nansum(wc* π_e[x,1-a] / π_b_[x,1-a] * rc)) / (np.sum(w* π_e[x,a] / π_b_[x,a]) + np.sum((wc* π_e[x,1-a] / π_b_[x,1-a])[c==1])),\n", + " 4%|▍ | 1/25 [00:04<01:41, 4.24s/it]:33: RuntimeWarning: invalid value encountered in double_scalars\n", + " WISs.append(np.sum(rho * r) / np.sum(rho))\n", + 
"100%|██████████| 25/25 [02:03<00:00, 4.93s/it]\n" + ] + } + ], + "source": [ + "df_out_all_1 = []\n", + "for π_b, π_e in tqdm(list(itertools.product(πs, πs))):\n", + " df_out = single_exp_setting(π_b, π_e)\n", + " df_out_all_1.append(df_out)" + ] + }, + { + "cell_type": "code", + "execution_count": 16, + "id": "3644f5a4-d3be-4c44-bdee-92f5728b6c6e", + "metadata": {}, + "outputs": [ + { + "data": { + "text/html": [ + "\n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + "
π_e [1. 0.]\n", + " [1. 0.] [0. 1.]\n", + " [1. 0.] [0.5 0.5]\n", + " [1. 0. ] [0.1 0.9]\n", + " [1. 0. ] [0.8 0.2]\n", + " [1. 0. ]
π_b
[1. 0.]\n", + " [1. 0.][ 0. 0.7 0.7]\n", + " [-0.1 1.1 1.1]\n", + " [ 0. 0.7 0.7][-0.5 0.4 0.6]\n", + " [ 0.2 1.1 1.1]\n", + " [ 0. 0.7 0.7][-0.2 0.5 0.5]\n", + " [ 0. 0.9 0.9]\n", + " [ 0. 0.4 0.4][-0.4 0.4 0.6]\n", + " [ 0.2 1.1 1.1]\n", + " [ 0. 0.6 0.6][-0.1 0.6 0.6]\n", + " [-0.1 1. 1. ]\n", + " [ 0. 0.6 0.6]
[0. 1.]\n", + " [1. 0.][ 0.5 0.4 0.6]\n", + " [-0.1 1.1 1.1]\n", + " [ 0. 0.7 0.7][0. 0.7 0.7]\n", + " [0.2 1.1 1.2]\n", + " [0. 0.7 0.7][0.3 0.5 0.5]\n", + " [0. 0.9 0.9]\n", + " [0. 0.4 0.4][0.1 0.6 0.7]\n", + " [0.2 1.1 1.1]\n", + " [0. 0.6 0.6][ 0.4 0.4 0.6]\n", + " [-0.1 1. 1. ]\n", + " [ 0. 0.6 0.6]
[0.5 0.5]\n", + " [1. 0. ][ 0. 1.1 1.1]\n", + " [-0.1 1.1 1.1]\n", + " [ 0. 0.7 0.7][0. 1.1 1.1]\n", + " [0.2 1.2 1.2]\n", + " [0. 0.7 0.7][0. 0.9 0.9]\n", + " [0. 1. 1. ]\n", + " [0. 0.4 0.4][0. 1. 1. ]\n", + " [0.2 1.1 1.1]\n", + " [0. 0.6 0.6][ 0. 0.9 0.9]\n", + " [-0.1 1. 1. ]\n", + " [ 0. 0.6 0.6]
[0.1 0.9]\n", + " [1. 0. ][-0. 2.4 2.4]\n", + " [-0.1 1.1 1.1]\n", + " [ 0. 0.7 0.7][0. 0.7 0.7]\n", + " [0.2 1.1 1.2]\n", + " [0. 0.7 0.7][0. 1.4 1.4]\n", + " [0. 0.9 0.9]\n", + " [0. 0.4 0.4][0. 0.8 0.8]\n", + " [0.2 1.1 1.1]\n", + " [0. 0.6 0.6][-0. 2. 2. ]\n", + " [-0.1 1. 1. ]\n", + " [ 0. 0.6 0.6]
[0.8 0.2]\n", + " [1. 0. ][ 0. 0.8 0.8]\n", + " [-0.1 1.1 1.1]\n", + " [ 0.1 0.7 0.7][0.1 1.8 1.8]\n", + " [0.2 1.1 1.2]\n", + " [0. 0.7 0.7][0. 1.1 1.1]\n", + " [0. 0.9 0.9]\n", + " [0. 0.4 0.4][0.1 1.7 1.7]\n", + " [0.1 1.1 1.1]\n", + " [0. 0.6 0.6][ 0. 0.8 0.8]\n", + " [-0.1 1. 1. ]\n", + " [ 0. 0.5 0.5]
" + ], + "text/plain": [ + "" + ] + }, + "metadata": {}, + "output_type": "display_data" + } + ], + "source": [ + "df_tmp = pd.DataFrame(index=[str(π)[1:-1] for π in πs], columns=[str(π)[1:-1] for π in πs])\n", + "df_tmp.index.name = 'π_b'\n", + "df_tmp.columns.name = 'π_e'\n", + "for (i, π_b), (j, π_e) in itertools.product(enumerate(πs), enumerate(πs)):\n", + " ix = i*len(πs)+j\n", + " df_tmp.iloc[i,j] = str(df_out_all_1[ix].iloc[[2,6,4], [2,3,4]].round(1).values)[1:-1]\n", + "display(df_tmp.style.set_properties(**{'white-space': 'pre-wrap'}))" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "9e7abac0-37db-4c9c-ad22-51e24a4a669c", + "metadata": { + "tags": [] + }, + "outputs": [], + "source": [ + "for (i, π_b) in enumerate(πs):\n", + " print(\"\"\"$[\\mask{{0.0}}{{{}}},\\mask{{0.0}}{{{}}}]$ \"\"\".format(π_b[0,0], π_b[0,1]))\n", + " for (j, π_e) in enumerate(πs):\n", + " ix = i*len(πs)+j\n", + " values = df_out_all_1[ix].iloc[[2,6,4], [2,3,4]].round(2).values\n", + " print(\"\"\"& \\\\scalebox{0.8}{$\\\\begin{matrix} \"\"\"\n", + " + \"\"\"\\mask{{0.00}}{{{}}} & \\mask{{0.00}}{{{}}} & \\mask{{0.00}}{{{}}} \\\\\\\\ \\mask{{0.00}}{{{}}} & \\mask{{0.00}}{{{}}} & \\mask{{0.00}}{{{}}} \\\\\\\\ \\mask{{0.00}}{{{}}} & \\mask{{0.00}}{{{}}} & \\mask{{0.00}}{{{}}}\"\"\".format(*list(values.ravel())).replace('-', '\\\\shortminus ')\n", + " + \"\"\" \\end{matrix}$} \"\"\")\n", + " print(\"\"\"\\\\\\\\[12pt]\"\"\")" + ] + }, + { + "cell_type": "code", + "execution_count": 17, + "id": "f608f470-d130-4270-96bc-9c96912985d5", + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "$[\\mask{0.0}{1.0},\\mask{0.0}{0.0}]$ \n", + "& \\scalebox{0.8}{$\\begin{matrix} \\mask{0.00}{0} & \\mask{0.00}{0.7} & \\mask{0.00}{0.7} \\\\ \\mask{0.00}{\\shortminus 0.1} & \\mask{0.00}{1.1} & \\mask{0.00}{1.1} \\\\ \\mask{0.00}{0} & \\mask{0.00}{0.7} & \\mask{0.00}{0.7} \\end{matrix}$} \n", + "& \\scalebox{0.8}{$\\begin{matrix} 
\\mask{0.00}{\\shortminus 0.5} & \\mask{0.00}{0.4} & \\mask{0.00}{0.6} \\\\ \\mask{0.00}{0.2} & \\mask{0.00}{1.1} & \\mask{0.00}{1.1} \\\\ \\mask{0.00}{0} & \\mask{0.00}{0.7} & \\mask{0.00}{0.7} \\end{matrix}$} \n", + "& \\scalebox{0.8}{$\\begin{matrix} \\mask{0.00}{\\shortminus 0.2} & \\mask{0.00}{0.5} & \\mask{0.00}{0.5} \\\\ \\mask{0.00}{0} & \\mask{0.00}{0.9} & \\mask{0.00}{0.9} \\\\ \\mask{0.00}{0} & \\mask{0.00}{0.4} & \\mask{0.00}{0.4} \\end{matrix}$} \n", + "& \\scalebox{0.8}{$\\begin{matrix} \\mask{0.00}{\\shortminus 0.4} & \\mask{0.00}{0.4} & \\mask{0.00}{0.6} \\\\ \\mask{0.00}{0.2} & \\mask{0.00}{1.1} & \\mask{0.00}{1.1} \\\\ \\mask{0.00}{0} & \\mask{0.00}{0.6} & \\mask{0.00}{0.6} \\end{matrix}$} \n", + "& \\scalebox{0.8}{$\\begin{matrix} \\mask{0.00}{\\shortminus 0.1} & \\mask{0.00}{0.6} & \\mask{0.00}{0.6} \\\\ \\mask{0.00}{\\shortminus 0.1} & \\mask{0.00}{1} & \\mask{0.00}{1} \\\\ \\mask{0.00}{0} & \\mask{0.00}{0.6} & \\mask{0.00}{0.6} \\end{matrix}$} \n", + "\\\\[12pt]\n", + "$[\\mask{0.0}{0.0},\\mask{0.0}{1.0}]$ \n", + "& \\scalebox{0.8}{$\\begin{matrix} \\mask{0.00}{0.5} & \\mask{0.00}{0.4} & \\mask{0.00}{0.6} \\\\ \\mask{0.00}{\\shortminus 0.1} & \\mask{0.00}{1.1} & \\mask{0.00}{1.1} \\\\ \\mask{0.00}{0} & \\mask{0.00}{0.7} & \\mask{0.00}{0.7} \\end{matrix}$} \n", + "& \\scalebox{0.8}{$\\begin{matrix} \\mask{0.00}{0} & \\mask{0.00}{0.7} & \\mask{0.00}{0.7} \\\\ \\mask{0.00}{0.2} & \\mask{0.00}{1.1} & \\mask{0.00}{1.2} \\\\ \\mask{0.00}{0} & \\mask{0.00}{0.7} & \\mask{0.00}{0.7} \\end{matrix}$} \n", + "& \\scalebox{0.8}{$\\begin{matrix} \\mask{0.00}{0.3} & \\mask{0.00}{0.5} & \\mask{0.00}{0.5} \\\\ \\mask{0.00}{0} & \\mask{0.00}{0.9} & \\mask{0.00}{0.9} \\\\ \\mask{0.00}{0} & \\mask{0.00}{0.4} & \\mask{0.00}{0.4} \\end{matrix}$} \n", + "& \\scalebox{0.8}{$\\begin{matrix} \\mask{0.00}{0.1} & \\mask{0.00}{0.6} & \\mask{0.00}{0.7} \\\\ \\mask{0.00}{0.2} & \\mask{0.00}{1.1} & \\mask{0.00}{1.1} \\\\ \\mask{0.00}{0} & \\mask{0.00}{0.6} & 
\\mask{0.00}{0.6} \\end{matrix}$} \n", + "& \\scalebox{0.8}{$\\begin{matrix} \\mask{0.00}{0.4} & \\mask{0.00}{0.4} & \\mask{0.00}{0.6} \\\\ \\mask{0.00}{\\shortminus 0.1} & \\mask{0.00}{1} & \\mask{0.00}{1} \\\\ \\mask{0.00}{0} & \\mask{0.00}{0.6} & \\mask{0.00}{0.6} \\end{matrix}$} \n", + "\\\\[12pt]\n", + "$[\\mask{0.0}{0.5},\\mask{0.0}{0.5}]$ \n", + "& \\scalebox{0.8}{$\\begin{matrix} \\mask{0.00}{0} & \\mask{0.00}{1.1} & \\mask{0.00}{1.1} \\\\ \\mask{0.00}{\\shortminus 0.1} & \\mask{0.00}{1.1} & \\mask{0.00}{1.1} \\\\ \\mask{0.00}{0} & \\mask{0.00}{0.7} & \\mask{0.00}{0.7} \\end{matrix}$} \n", + "& \\scalebox{0.8}{$\\begin{matrix} \\mask{0.00}{0} & \\mask{0.00}{1.1} & \\mask{0.00}{1.1} \\\\ \\mask{0.00}{0.2} & \\mask{0.00}{1.2} & \\mask{0.00}{1.2} \\\\ \\mask{0.00}{0} & \\mask{0.00}{0.7} & \\mask{0.00}{0.7} \\end{matrix}$} \n", + "& \\scalebox{0.8}{$\\begin{matrix} \\mask{0.00}{0} & \\mask{0.00}{0.9} & \\mask{0.00}{0.9} \\\\ \\mask{0.00}{0} & \\mask{0.00}{1} & \\mask{0.00}{1} \\\\ \\mask{0.00}{0} & \\mask{0.00}{0.4} & \\mask{0.00}{0.4} \\end{matrix}$} \n", + "& \\scalebox{0.8}{$\\begin{matrix} \\mask{0.00}{0} & \\mask{0.00}{1} & \\mask{0.00}{1} \\\\ \\mask{0.00}{0.2} & \\mask{0.00}{1.1} & \\mask{0.00}{1.1} \\\\ \\mask{0.00}{0} & \\mask{0.00}{0.6} & \\mask{0.00}{0.6} \\end{matrix}$} \n", + "& \\scalebox{0.8}{$\\begin{matrix} \\mask{0.00}{0} & \\mask{0.00}{0.9} & \\mask{0.00}{0.9} \\\\ \\mask{0.00}{\\shortminus 0.1} & \\mask{0.00}{1} & \\mask{0.00}{1} \\\\ \\mask{0.00}{0} & \\mask{0.00}{0.6} & \\mask{0.00}{0.6} \\end{matrix}$} \n", + "\\\\[12pt]\n", + "$[\\mask{0.0}{0.1},\\mask{0.0}{0.9}]$ \n", + "& \\scalebox{0.8}{$\\begin{matrix} \\mask{0.00}{0} & \\mask{0.00}{2.4} & \\mask{0.00}{2.4} \\\\ \\mask{0.00}{\\shortminus 0.1} & \\mask{0.00}{1.1} & \\mask{0.00}{1.1} \\\\ \\mask{0.00}{0} & \\mask{0.00}{0.7} & \\mask{0.00}{0.7} \\end{matrix}$} \n", + "& \\scalebox{0.8}{$\\begin{matrix} \\mask{0.00}{0} & \\mask{0.00}{0.7} & \\mask{0.00}{0.7} \\\\ \\mask{0.00}{0.2} & 
\\mask{0.00}{1.1} & \\mask{0.00}{1.2} \\\\ \\mask{0.00}{0} & \\mask{0.00}{0.7} & \\mask{0.00}{0.7} \\end{matrix}$} \n", + "& \\scalebox{0.8}{$\\begin{matrix} \\mask{0.00}{0} & \\mask{0.00}{1.4} & \\mask{0.00}{1.4} \\\\ \\mask{0.00}{0} & \\mask{0.00}{0.9} & \\mask{0.00}{0.9} \\\\ \\mask{0.00}{0} & \\mask{0.00}{0.4} & \\mask{0.00}{0.4} \\end{matrix}$} \n", + "& \\scalebox{0.8}{$\\begin{matrix} \\mask{0.00}{0} & \\mask{0.00}{0.8} & \\mask{0.00}{0.8} \\\\ \\mask{0.00}{0.2} & \\mask{0.00}{1.1} & \\mask{0.00}{1.1} \\\\ \\mask{0.00}{0} & \\mask{0.00}{0.6} & \\mask{0.00}{0.6} \\end{matrix}$} \n", + "& \\scalebox{0.8}{$\\begin{matrix} \\mask{0.00}{0} & \\mask{0.00}{2} & \\mask{0.00}{2} \\\\ \\mask{0.00}{\\shortminus 0.1} & \\mask{0.00}{1} & \\mask{0.00}{1} \\\\ \\mask{0.00}{0} & \\mask{0.00}{0.6} & \\mask{0.00}{0.6} \\end{matrix}$} \n", + "\\\\[12pt]\n", + "$[\\mask{0.0}{0.8},\\mask{0.0}{0.2}]$ \n", + "& \\scalebox{0.8}{$\\begin{matrix} \\mask{0.00}{0} & \\mask{0.00}{0.8} & \\mask{0.00}{0.8} \\\\ \\mask{0.00}{\\shortminus 0.1} & \\mask{0.00}{1.1} & \\mask{0.00}{1.1} \\\\ \\mask{0.00}{0.1} & \\mask{0.00}{0.7} & \\mask{0.00}{0.7} \\end{matrix}$} \n", + "& \\scalebox{0.8}{$\\begin{matrix} \\mask{0.00}{0.1} & \\mask{0.00}{1.8} & \\mask{0.00}{1.8} \\\\ \\mask{0.00}{0.2} & \\mask{0.00}{1.1} & \\mask{0.00}{1.2} \\\\ \\mask{0.00}{0} & \\mask{0.00}{0.7} & \\mask{0.00}{0.7} \\end{matrix}$} \n", + "& \\scalebox{0.8}{$\\begin{matrix} \\mask{0.00}{0} & \\mask{0.00}{1.1} & \\mask{0.00}{1.1} \\\\ \\mask{0.00}{0} & \\mask{0.00}{0.9} & \\mask{0.00}{0.9} \\\\ \\mask{0.00}{0} & \\mask{0.00}{0.4} & \\mask{0.00}{0.4} \\end{matrix}$} \n", + "& \\scalebox{0.8}{$\\begin{matrix} \\mask{0.00}{0.1} & \\mask{0.00}{1.7} & \\mask{0.00}{1.7} \\\\ \\mask{0.00}{0.1} & \\mask{0.00}{1.1} & \\mask{0.00}{1.1} \\\\ \\mask{0.00}{0} & \\mask{0.00}{0.6} & \\mask{0.00}{0.6} \\end{matrix}$} \n", + "& \\scalebox{0.8}{$\\begin{matrix} \\mask{0.00}{0} & \\mask{0.00}{0.8} & \\mask{0.00}{0.8} \\\\ 
\\mask{0.00}{\\shortminus 0.1} & \\mask{0.00}{1} & \\mask{0.00}{1} \\\\ \\mask{0.00}{0} & \\mask{0.00}{0.5} & \\mask{0.00}{0.5} \\end{matrix}$} \n", + "\\\\[12pt]\n" + ] + } + ], + "source": [ + "for (i, π_b) in enumerate(πs):\n", + " print(\"\"\"$[\\mask{{0.0}}{{{}}},\\mask{{0.0}}{{{}}}]$ \"\"\".format(π_b[0,0], π_b[0,1]))\n", + " for (j, π_e) in enumerate(πs):\n", + " ix = i*len(πs)+j\n", + " values = df_out_all_1[ix].iloc[[2,6,4], [2,3,4]].round(1).values\n", + " print(\"\"\"& \\\\scalebox{0.8}{$\\\\begin{matrix} \"\"\"\n", + " + \"\"\"\\mask{{0.00}}{{{}}} & \\mask{{0.00}}{{{}}} & \\mask{{0.00}}{{{}}} \\\\\\\\ \\mask{{0.00}}{{{}}} & \\mask{{0.00}}{{{}}} & \\mask{{0.00}}{{{}}} \\\\\\\\ \\mask{{0.00}}{{{}}} & \\mask{{0.00}}{{{}}} & \\mask{{0.00}}{{{}}}\"\"\".format(\n", + " *[int(x) if x.is_integer() else x for x in list(values.ravel())]).replace('-', '\\\\shortminus ')\n", + " + \"\"\" \\end{matrix}$} \"\"\")\n", + " print(\"\"\"\\\\\\\\[12pt]\"\"\")" + ] + }, + { + "cell_type": "markdown", + "id": "bb78016d-f909-4262-bec6-1b701cb3413b", + "metadata": {}, + "source": [ + "## Rs[2] setting" + ] + }, + { + "cell_type": "code", + "execution_count": 18, + "id": "01ef0086-6dec-44a7-bd8f-9cdea919ca1e", + "metadata": {}, + "outputs": [], + "source": [ + "R, sigma = Rs[2], sigmas[2]" + ] + }, + { + "cell_type": "code", + "execution_count": 19, + "id": "359a1e3d-aa69-445e-9041-5072ed3b7cfe", + "metadata": {}, + "outputs": [ + { + "name": "stderr", + "output_type": "stream", + "text": [ + " 0%| | 0/25 [00:00:115: RuntimeWarning: invalid value encountered in true_divide\n", + " (np.sum(w* π_e[x,a] / π_b_[x,a] * r) + np.nansum(wc* π_e[x,1-a] / π_b_[x,1-a] * rc)) / (N)\n", + ":118: RuntimeWarning: invalid value encountered in true_divide\n", + " (np.sum(w* π_e[x,a] / π_b_[x,a] * r) + np.nansum(wc* π_e[x,1-a] / π_b_[x,1-a] * rc)) / (np.sum(w* π_e[x,a] / π_b_[x,a]) + np.sum((wc* π_e[x,1-a] / π_b_[x,1-a])[c==1])),\n", + " 4%|▍ | 1/25 [00:07<02:52, 7.17s/it]:33: RuntimeWarning: 
invalid value encountered in double_scalars\n", + " WISs.append(np.sum(rho * r) / np.sum(rho))\n", + "100%|██████████| 25/25 [02:04<00:00, 4.98s/it]\n" + ] + } + ], + "source": [ + "df_out_all_2 = []\n", + "for π_b, π_e in tqdm(list(itertools.product(πs, πs))):\n", + " df_out = single_exp_setting(π_b, π_e)\n", + " df_out_all_2.append(df_out)" + ] + }, + { + "cell_type": "code", + "execution_count": 20, + "id": "9549c5a2-e84d-4a5d-b497-4bb85df4fff1", + "metadata": {}, + "outputs": [ + { + "data": { + "text/html": [ + "\n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + "
π_e [1. 0.]\n", + " [1. 0.] [0. 1.]\n", + " [1. 0.] [0.5 0.5]\n", + " [1. 0. ] [0.1 0.9]\n", + " [1. 0. ] [0.8 0.2]\n", + " [1. 0. ]
π_b
[1. 0.]\n", + " [1. 0.][ 0. 0.7 0.7]\n", + " [-0.1 1.1 1.1]\n", + " [ 0. 0.7 0.7][ 1. 0.4 1.1]\n", + " [-0.3 2. 2. ]\n", + " [ 0.1 1.1 1.1][ 0.5 0.5 0.7]\n", + " [-0.2 1. 1. ]\n", + " [ 0.1 0.9 0.9][ 0.9 0.4 1. ]\n", + " [-0.3 1.7 1.8]\n", + " [ 0.1 1.1 1.1][ 0.2 0.6 0.7]\n", + " [-0.2 0.9 0.9]\n", + " [ 0. 0.8 0.8]
[0. 1.]\n", + " [1. 0.][ 0.5 0.4 0.6]\n", + " [-0.1 1.1 1.1]\n", + " [ 0. 0.7 0.7][ 0.1 1.1 1.1]\n", + " [-0.3 2. 2. ]\n", + " [ 0.1 1.1 1.1][ 0.3 0.7 0.7]\n", + " [-0.2 1. 1. ]\n", + " [ 0.1 0.9 0.9][ 0.1 1. 1. ]\n", + " [-0.3 1.7 1.8]\n", + " [ 0.1 1.1 1.1][ 0.4 0.4 0.6]\n", + " [-0.2 0.9 0.9]\n", + " [ 0. 0.8 0.8]
[0.5 0.5]\n", + " [1. 0. ][ 0. 1.1 1.1]\n", + " [-0.1 1.1 1.1]\n", + " [ 0. 0.7 0.7][ 0.1 1.8 1.8]\n", + " [-0.3 2. 2. ]\n", + " [ 0.1 1.1 1.1][ 0.1 1. 1. ]\n", + " [-0.2 1. 1. ]\n", + " [ 0.1 0.9 0.9][ 0.1 1.6 1.6]\n", + " [-0.3 1.7 1.8]\n", + " [ 0.1 1.1 1.1][ 0. 0.9 0.9]\n", + " [-0.2 0.8 0.9]\n", + " [ 0. 0.8 0.8]
[0.1 0.9]\n", + " [1. 0. ][-0. 2.4 2.4]\n", + " [-0.1 1.1 1.1]\n", + " [ 0. 0.7 0.7][ 0.1 1.2 1.2]\n", + " [-0.3 2. 2. ]\n", + " [ 0.1 1.1 1.1][ 0. 1.3 1.3]\n", + " [-0.2 1. 1. ]\n", + " [ 0.1 0.9 0.9][ 0.1 1.1 1.1]\n", + " [-0.3 1.7 1.8]\n", + " [ 0.1 1.1 1.1][ 0. 1.9 1.9]\n", + " [-0.2 0.9 0.9]\n", + " [ 0. 0.8 0.8]
[0.8 0.2]\n", + " [1. 0. ][ 0. 0.8 0.8]\n", + " [-0.1 1.1 1.1]\n", + " [ 0.1 0.7 0.7][ 0.1 3.1 3.1]\n", + " [-0.3 2. 2. ]\n", + " [ 0. 1.1 1.1][ 0. 1.5 1.5]\n", + " [-0.2 1. 1. ]\n", + " [ 0.1 0.9 0.9][ 0.1 2.8 2.8]\n", + " [-0.3 1.8 1.8]\n", + " [ 0. 1.1 1.1][ 0. 0.8 0.8]\n", + " [-0.2 0.8 0.9]\n", + " [ 0.1 0.8 0.8]
" + ], + "text/plain": [ + "" + ] + }, + "metadata": {}, + "output_type": "display_data" + } + ], + "source": [ + "df_tmp = pd.DataFrame(index=[str(π)[1:-1] for π in πs], columns=[str(π)[1:-1] for π in πs])\n", + "df_tmp.index.name = 'π_b'\n", + "df_tmp.columns.name = 'π_e'\n", + "for (i, π_b), (j, π_e) in itertools.product(enumerate(πs), enumerate(πs)):\n", + " ix = i*len(πs)+j\n", + " df_tmp.iloc[i,j] = str(df_out_all_2[ix].iloc[[2,6,4], [2,3,4]].round(1).values)[1:-1]\n", + "display(df_tmp.style.set_properties(**{'white-space': 'pre-wrap'}))" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "d582afa1-895c-4b1f-afdc-4f962ec47031", + "metadata": { + "tags": [] + }, + "outputs": [], + "source": [ + "for (i, π_b) in enumerate(πs):\n", + " print(\"\"\"$[\\mask{{0.0}}{{{}}},\\mask{{0.0}}{{{}}}]$ \"\"\".format(π_b[0,0], π_b[0,1]))\n", + " for (j, π_e) in enumerate(πs):\n", + " ix = i*len(πs)+j\n", + " values = df_out_all_2[ix].iloc[[2,6,4], [2,3,4]].round(2).values\n", + " print(\"\"\"& \\\\scalebox{0.8}{$\\\\begin{matrix} \"\"\"\n", + " + \"\"\"\\mask{{0.00}}{{{}}} & \\mask{{0.00}}{{{}}} & \\mask{{0.00}}{{{}}} \\\\\\\\ \\mask{{0.00}}{{{}}} & \\mask{{0.00}}{{{}}} & \\mask{{0.00}}{{{}}} \\\\\\\\ \\mask{{0.00}}{{{}}} & \\mask{{0.00}}{{{}}} & \\mask{{0.00}}{{{}}}\"\"\".format(*list(values.ravel())).replace('-', '\\\\shortminus ')\n", + " + \"\"\" \\end{matrix}$} \"\"\")\n", + " print(\"\"\"\\\\\\\\[12pt]\"\"\")" + ] + }, + { + "cell_type": "code", + "execution_count": 21, + "id": "c135e809-4f3a-4c72-a77e-40fe7ac04209", + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "$[\\mask{0.0}{1.0},\\mask{0.0}{0.0}]$ \n", + "& \\scalebox{0.8}{$\\begin{matrix} \\mask{0.00}{0} & \\mask{0.00}{0.7} & \\mask{0.00}{0.7} \\\\ \\mask{0.00}{\\shortminus 0.1} & \\mask{0.00}{1.1} & \\mask{0.00}{1.1} \\\\ \\mask{0.00}{0} & \\mask{0.00}{0.7} & \\mask{0.00}{0.7} \\end{matrix}$} \n", + "& \\scalebox{0.8}{$\\begin{matrix} 
\\mask{0.00}{1} & \\mask{0.00}{0.4} & \\mask{0.00}{1.1} \\\\ \\mask{0.00}{\\shortminus 0.3} & \\mask{0.00}{2} & \\mask{0.00}{2} \\\\ \\mask{0.00}{0.1} & \\mask{0.00}{1.1} & \\mask{0.00}{1.1} \\end{matrix}$} \n", + "& \\scalebox{0.8}{$\\begin{matrix} \\mask{0.00}{0.5} & \\mask{0.00}{0.5} & \\mask{0.00}{0.7} \\\\ \\mask{0.00}{\\shortminus 0.2} & \\mask{0.00}{1} & \\mask{0.00}{1} \\\\ \\mask{0.00}{0.1} & \\mask{0.00}{0.9} & \\mask{0.00}{0.9} \\end{matrix}$} \n", + "& \\scalebox{0.8}{$\\begin{matrix} \\mask{0.00}{0.9} & \\mask{0.00}{0.4} & \\mask{0.00}{1} \\\\ \\mask{0.00}{\\shortminus 0.3} & \\mask{0.00}{1.7} & \\mask{0.00}{1.8} \\\\ \\mask{0.00}{0.1} & \\mask{0.00}{1.1} & \\mask{0.00}{1.1} \\end{matrix}$} \n", + "& \\scalebox{0.8}{$\\begin{matrix} \\mask{0.00}{0.2} & \\mask{0.00}{0.6} & \\mask{0.00}{0.7} \\\\ \\mask{0.00}{\\shortminus 0.2} & \\mask{0.00}{0.9} & \\mask{0.00}{0.9} \\\\ \\mask{0.00}{0} & \\mask{0.00}{0.8} & \\mask{0.00}{0.8} \\end{matrix}$} \n", + "\\\\[12pt]\n", + "$[\\mask{0.0}{0.0},\\mask{0.0}{1.0}]$ \n", + "& \\scalebox{0.8}{$\\begin{matrix} \\mask{0.00}{0.5} & \\mask{0.00}{0.4} & \\mask{0.00}{0.6} \\\\ \\mask{0.00}{\\shortminus 0.1} & \\mask{0.00}{1.1} & \\mask{0.00}{1.1} \\\\ \\mask{0.00}{0} & \\mask{0.00}{0.7} & \\mask{0.00}{0.7} \\end{matrix}$} \n", + "& \\scalebox{0.8}{$\\begin{matrix} \\mask{0.00}{0.1} & \\mask{0.00}{1.1} & \\mask{0.00}{1.1} \\\\ \\mask{0.00}{\\shortminus 0.3} & \\mask{0.00}{2} & \\mask{0.00}{2} \\\\ \\mask{0.00}{0.1} & \\mask{0.00}{1.1} & \\mask{0.00}{1.1} \\end{matrix}$} \n", + "& \\scalebox{0.8}{$\\begin{matrix} \\mask{0.00}{0.3} & \\mask{0.00}{0.7} & \\mask{0.00}{0.7} \\\\ \\mask{0.00}{\\shortminus 0.2} & \\mask{0.00}{1} & \\mask{0.00}{1} \\\\ \\mask{0.00}{0.1} & \\mask{0.00}{0.9} & \\mask{0.00}{0.9} \\end{matrix}$} \n", + "& \\scalebox{0.8}{$\\begin{matrix} \\mask{0.00}{0.1} & \\mask{0.00}{1} & \\mask{0.00}{1} \\\\ \\mask{0.00}{\\shortminus 0.3} & \\mask{0.00}{1.7} & \\mask{0.00}{1.8} \\\\ \\mask{0.00}{0.1} & 
\\mask{0.00}{1.1} & \\mask{0.00}{1.1} \\end{matrix}$} \n", + "& \\scalebox{0.8}{$\\begin{matrix} \\mask{0.00}{0.4} & \\mask{0.00}{0.4} & \\mask{0.00}{0.6} \\\\ \\mask{0.00}{\\shortminus 0.2} & \\mask{0.00}{0.9} & \\mask{0.00}{0.9} \\\\ \\mask{0.00}{0} & \\mask{0.00}{0.8} & \\mask{0.00}{0.8} \\end{matrix}$} \n", + "\\\\[12pt]\n", + "$[\\mask{0.0}{0.5},\\mask{0.0}{0.5}]$ \n", + "& \\scalebox{0.8}{$\\begin{matrix} \\mask{0.00}{0} & \\mask{0.00}{1.1} & \\mask{0.00}{1.1} \\\\ \\mask{0.00}{\\shortminus 0.1} & \\mask{0.00}{1.1} & \\mask{0.00}{1.1} \\\\ \\mask{0.00}{0} & \\mask{0.00}{0.7} & \\mask{0.00}{0.7} \\end{matrix}$} \n", + "& \\scalebox{0.8}{$\\begin{matrix} \\mask{0.00}{0.1} & \\mask{0.00}{1.8} & \\mask{0.00}{1.8} \\\\ \\mask{0.00}{\\shortminus 0.3} & \\mask{0.00}{2} & \\mask{0.00}{2} \\\\ \\mask{0.00}{0.1} & \\mask{0.00}{1.1} & \\mask{0.00}{1.1} \\end{matrix}$} \n", + "& \\scalebox{0.8}{$\\begin{matrix} \\mask{0.00}{0.1} & \\mask{0.00}{1} & \\mask{0.00}{1} \\\\ \\mask{0.00}{\\shortminus 0.2} & \\mask{0.00}{1} & \\mask{0.00}{1} \\\\ \\mask{0.00}{0.1} & \\mask{0.00}{0.9} & \\mask{0.00}{0.9} \\end{matrix}$} \n", + "& \\scalebox{0.8}{$\\begin{matrix} \\mask{0.00}{0.1} & \\mask{0.00}{1.6} & \\mask{0.00}{1.6} \\\\ \\mask{0.00}{\\shortminus 0.3} & \\mask{0.00}{1.7} & \\mask{0.00}{1.8} \\\\ \\mask{0.00}{0.1} & \\mask{0.00}{1.1} & \\mask{0.00}{1.1} \\end{matrix}$} \n", + "& \\scalebox{0.8}{$\\begin{matrix} \\mask{0.00}{0} & \\mask{0.00}{0.9} & \\mask{0.00}{0.9} \\\\ \\mask{0.00}{\\shortminus 0.2} & \\mask{0.00}{0.8} & \\mask{0.00}{0.9} \\\\ \\mask{0.00}{0} & \\mask{0.00}{0.8} & \\mask{0.00}{0.8} \\end{matrix}$} \n", + "\\\\[12pt]\n", + "$[\\mask{0.0}{0.1},\\mask{0.0}{0.9}]$ \n", + "& \\scalebox{0.8}{$\\begin{matrix} \\mask{0.00}{0} & \\mask{0.00}{2.4} & \\mask{0.00}{2.4} \\\\ \\mask{0.00}{\\shortminus 0.1} & \\mask{0.00}{1.1} & \\mask{0.00}{1.1} \\\\ \\mask{0.00}{0} & \\mask{0.00}{0.7} & \\mask{0.00}{0.7} \\end{matrix}$} \n", + "& \\scalebox{0.8}{$\\begin{matrix} 
\\mask{0.00}{0.1} & \\mask{0.00}{1.2} & \\mask{0.00}{1.2} \\\\ \\mask{0.00}{\\shortminus 0.3} & \\mask{0.00}{2} & \\mask{0.00}{2} \\\\ \\mask{0.00}{0.1} & \\mask{0.00}{1.1} & \\mask{0.00}{1.1} \\end{matrix}$} \n", + "& \\scalebox{0.8}{$\\begin{matrix} \\mask{0.00}{0} & \\mask{0.00}{1.3} & \\mask{0.00}{1.3} \\\\ \\mask{0.00}{\\shortminus 0.2} & \\mask{0.00}{1} & \\mask{0.00}{1} \\\\ \\mask{0.00}{0.1} & \\mask{0.00}{0.9} & \\mask{0.00}{0.9} \\end{matrix}$} \n", + "& \\scalebox{0.8}{$\\begin{matrix} \\mask{0.00}{0.1} & \\mask{0.00}{1.1} & \\mask{0.00}{1.1} \\\\ \\mask{0.00}{\\shortminus 0.3} & \\mask{0.00}{1.7} & \\mask{0.00}{1.8} \\\\ \\mask{0.00}{0.1} & \\mask{0.00}{1.1} & \\mask{0.00}{1.1} \\end{matrix}$} \n", + "& \\scalebox{0.8}{$\\begin{matrix} \\mask{0.00}{0} & \\mask{0.00}{1.9} & \\mask{0.00}{1.9} \\\\ \\mask{0.00}{\\shortminus 0.2} & \\mask{0.00}{0.9} & \\mask{0.00}{0.9} \\\\ \\mask{0.00}{0} & \\mask{0.00}{0.8} & \\mask{0.00}{0.8} \\end{matrix}$} \n", + "\\\\[12pt]\n", + "$[\\mask{0.0}{0.8},\\mask{0.0}{0.2}]$ \n", + "& \\scalebox{0.8}{$\\begin{matrix} \\mask{0.00}{0} & \\mask{0.00}{0.8} & \\mask{0.00}{0.8} \\\\ \\mask{0.00}{\\shortminus 0.1} & \\mask{0.00}{1.1} & \\mask{0.00}{1.1} \\\\ \\mask{0.00}{0.1} & \\mask{0.00}{0.7} & \\mask{0.00}{0.7} \\end{matrix}$} \n", + "& \\scalebox{0.8}{$\\begin{matrix} \\mask{0.00}{0.1} & \\mask{0.00}{3.1} & \\mask{0.00}{3.1} \\\\ \\mask{0.00}{\\shortminus 0.3} & \\mask{0.00}{2} & \\mask{0.00}{2} \\\\ \\mask{0.00}{0} & \\mask{0.00}{1.1} & \\mask{0.00}{1.1} \\end{matrix}$} \n", + "& \\scalebox{0.8}{$\\begin{matrix} \\mask{0.00}{0} & \\mask{0.00}{1.5} & \\mask{0.00}{1.5} \\\\ \\mask{0.00}{\\shortminus 0.2} & \\mask{0.00}{1} & \\mask{0.00}{1} \\\\ \\mask{0.00}{0.1} & \\mask{0.00}{0.9} & \\mask{0.00}{0.9} \\end{matrix}$} \n", + "& \\scalebox{0.8}{$\\begin{matrix} \\mask{0.00}{0.1} & \\mask{0.00}{2.8} & \\mask{0.00}{2.8} \\\\ \\mask{0.00}{\\shortminus 0.3} & \\mask{0.00}{1.8} & \\mask{0.00}{1.8} \\\\ \\mask{0.00}{0} & 
\\mask{0.00}{1.1} & \\mask{0.00}{1.1} \\end{matrix}$} \n", + "& \\scalebox{0.8}{$\\begin{matrix} \\mask{0.00}{0} & \\mask{0.00}{0.8} & \\mask{0.00}{0.8} \\\\ \\mask{0.00}{\\shortminus 0.2} & \\mask{0.00}{0.8} & \\mask{0.00}{0.9} \\\\ \\mask{0.00}{0.1} & \\mask{0.00}{0.8} & \\mask{0.00}{0.8} \\end{matrix}$} \n", + "\\\\[12pt]\n" + ] + } + ], + "source": [ + "for (i, π_b) in enumerate(πs):\n", + " print(\"\"\"$[\\mask{{0.0}}{{{}}},\\mask{{0.0}}{{{}}}]$ \"\"\".format(π_b[0,0], π_b[0,1]))\n", + " for (j, π_e) in enumerate(πs):\n", + " ix = i*len(πs)+j\n", + " values = df_out_all_2[ix].iloc[[2,6,4], [2,3,4]].round(1).values\n", + " print(\"\"\"& \\\\scalebox{0.8}{$\\\\begin{matrix} \"\"\"\n", + " + \"\"\"\\mask{{0.00}}{{{}}} & \\mask{{0.00}}{{{}}} & \\mask{{0.00}}{{{}}} \\\\\\\\ \\mask{{0.00}}{{{}}} & \\mask{{0.00}}{{{}}} & \\mask{{0.00}}{{{}}} \\\\\\\\ \\mask{{0.00}}{{{}}} & \\mask{{0.00}}{{{}}} & \\mask{{0.00}}{{{}}}\"\"\".format(\n", + " *[int(x) if x.is_integer() else x for x in list(values.ravel())]).replace('-', '\\\\shortminus ')\n", + " + \"\"\" \\end{matrix}$} \"\"\")\n", + " print(\"\"\"\\\\\\\\[12pt]\"\"\")" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "e0fa9eb3-bec0-4429-96ae-075c203d883f", + "metadata": {}, + "outputs": [], + "source": [] + } + ], + "metadata": { + "kernelspec": { + "display_name": "RL_venv", + "language": "python", + "name": "rl_venv" + }, + "language_info": { + "codemirror_mode": { + "name": "ipython", + "version": 3 + }, + "file_extension": ".py", + "mimetype": "text/x-python", + "name": "python", + "nbconvert_exporter": "python", + "pygments_lexer": "ipython3", + "version": "3.8.8" + } + }, + "nbformat": 4, + "nbformat_minor": 5 +} diff --git a/synthetic/bandit_sweepPcannot.ipynb b/synthetic/bandit_sweepPcannot.ipynb new file mode 100644 index 0000000..4f2af65 --- /dev/null +++ b/synthetic/bandit_sweepPcannot.ipynb @@ -0,0 +1,9592 @@ +{ + "cells": [ + { + "cell_type": "code", + "execution_count": 1, + "id": 
"c7641324-81a4-4d2b-96cc-984469e52494", + "metadata": {}, + "outputs": [], + "source": [ + "import numpy as np\n", + "import pandas as pd\n", + "from IPython.display import display\n", + "import matplotlib.pyplot as plt\n", + "%config InlineBackend.figure_formats = ['svg']\n", + "import matplotlib\n", + "matplotlib.rcParams['text.usetex'] = True\n", + "matplotlib.rcParams['font.sans-serif'] = ['FreeSans']\n", + "import seaborn as sns\n", + "import itertools\n", + "from tqdm import tqdm\n", + "import joblib" + ] + }, + { + "cell_type": "code", + "execution_count": 2, + "id": "3fedec9e-42cf-4a21-9343-4c288ccb4268", + "metadata": {}, + "outputs": [], + "source": [ + "d0 = np.array([1.])" + ] + }, + { + "cell_type": "code", + "execution_count": 3, + "id": "65192074-bcac-4807-b269-9aaccd9016df", + "metadata": {}, + "outputs": [], + "source": [ + "use_πD = False" + ] + }, + { + "cell_type": "code", + "execution_count": 28, + "id": "c43eebdf-c021-41b4-ae98-305815f63144", + "metadata": {}, + "outputs": [], + "source": [ + "N = 200" + ] + }, + { + "cell_type": "markdown", + "id": "ba66ac61-e16a-4be9-b00a-04528c25b678", + "metadata": {}, + "source": [ + "# %Missing counterfactual annotations, constant mean split weights, no impute" + ] + }, + { + "cell_type": "code", + "execution_count": 5, + "id": "e8e28c62-746b-4b33-9e86-f72386bca996", + "metadata": {}, + "outputs": [], + "source": [ + "def single_run_pcannot(pc_annots, runs=1000, annot_std_scale=1.0):\n", + " np.random.seed(42)\n", + "\n", + " ## probability of getting a counterfactual annotation\n", + " Pc_local = pc_annots\n", + " \n", + " ww = ww_constant\n", + " \n", + " # True value of π_e\n", + " Js = []\n", + " for seed in range(runs):\n", + " rng = np.random.default_rng(seed=10+seed)\n", + " x = rng.choice(1, size=N, p=d0)\n", + " a = np.array([rng.choice(2, p=π_e[xi]) for xi in x])\n", + " r = np.array([rng.normal(R[xi,ai], sigma[xi,ai]) for xi,ai in zip(x,a)])\n", + " J = np.sum(r) / N\n", + " Js.append(J)\n", + 
"\n", + " # Standard IS\n", + " Gs = []\n", + " OISs = []\n", + " WISs = []\n", + " for seed in range(runs):\n", + " rng = np.random.default_rng(seed=10+seed)\n", + " x = rng.choice(1, size=N, p=d0)\n", + " a = np.array([rng.choice(2, p=π_b[xi]) for xi in x])\n", + " r = np.array([rng.normal(R[xi,ai], sigma[xi,ai]) for xi,ai in zip(x,a)])\n", + " G = np.sum(r) / N\n", + " Gs.append(G)\n", + "\n", + " if use_πD:\n", + " π_b_ = np.array([\n", + " [(np.sum((x==0)&(a==0)))/np.sum(x==0), \n", + " (np.sum((x==0)&(a==1)))/np.sum(x==0)],\n", + " ])\n", + " else:\n", + " π_b_ = π_b\n", + "\n", + " rho = π_e[x,a] / π_b_[x,a]\n", + " OISs.append(np.sum(rho * r) / N)\n", + " WISs.append(np.sum(rho * r) / np.sum(rho))\n", + "\n", + "\n", + " # Collect data using π_b - combining counterfactuals with factuals\n", + " FC_OISs_w = []\n", + " FC_WISs_w = []\n", + " for seed in range(runs):\n", + " rng = np.random.default_rng(seed=10+seed)\n", + " rng_c = np.random.default_rng(seed=100000+seed)\n", + " x = rng.choice(1, size=N, p=d0)\n", + " a = np.array([rng.choice(2, p=π_b[xi]) for xi in x])\n", + " r = np.array([rng.normal(R[xi,ai], sigma[xi,ai]) for xi,ai in zip(x,a)])\n", + " rho = π_e[x,a] / π_b[x,a]\n", + "\n", + " # counterfactual flag\n", + " c = np.array([rng_c.choice(2, p=[1-Pc_local[xi,ai], Pc_local[xi,ai]]) for xi,ai in zip(x,a)])\n", + " \n", + " # counterfactual reward\n", + " rc = np.array([rng_c.normal(R[xi,1-ai], annot_std_scale*sigma[xi,1-ai]) for xi,ai in zip(x,a)])\n", + " rc[c==0] = np.nan\n", + "\n", + " # trajectory-wise weight\n", + " w = np.ones(N)\n", + " w[c==1] = ww[x[c==1], a[c==1], a[c==1]]\n", + " wc = np.zeros(N)\n", + " wc[c==1] = ww[x[c==1], a[c==1], 1-a[c==1]]\n", + " \n", + " # print(x,a,r,c,rc,w,wc)\n", + " if use_πD:\n", + " # augmented behavior policy\n", + " π_b_ = np.array([\n", + " [(np.sum(w*((x==0)&(a==0)))+np.sum(wc*((x==0)&(a==1)&(c==1))))/(np.sum(w*(x==0))+np.sum(wc*((x==0)&(c==1)))), \n", + " 
(np.sum(w*((x==0)&(a==1)))+np.sum(wc*((x==0)&(a==0)&(c==1))))/(np.sum(w*(x==0))+np.sum(wc*((x==0)&(c==1))))],\n", + " ])\n", + " else:\n", + " # augmented behavior policy\n", + " π_b_ = np.array([\n", + " [(1-Pc_local[0,0])*π_b[0,0] + Pc_local[0,0]*π_b[0,0]*ww[0,0,0] + Pc_local[0,1]*π_b[0,1]*ww[0,1,0], \n", + " (1-Pc_local[0,1])*π_b[0,1] + Pc_local[0,0]*π_b[0,0]*ww[0,0,1] + Pc_local[0,1]*π_b[0,1]*ww[0,1,1]],\n", + " ])\n", + " π_b_ = π_b_ / π_b_.sum(axis=1, keepdims=True)\n", + "\n", + " # print(π_b_)\n", + " FC_OISs_w.append(\n", + " (np.sum(w* π_e[x,a] / π_b_[x,a] * r) + np.nansum(wc* π_e[x,1-a] / π_b_[x,1-a] * rc)) / (N)\n", + " )\n", + " FC_WISs_w.append(\n", + " (np.sum(w* π_e[x,a] / π_b_[x,a] * r) + np.nansum(wc* π_e[x,1-a] / π_b_[x,1-a] * rc)) / (np.sum(w* π_e[x,a] / π_b_[x,a]) + np.sum((wc* π_e[x,1-a] / π_b_[x,1-a])[c==1])),\n", + " )\n", + "\n", + " df_bias_var = []\n", + " for name, values in [\n", + " ('$\\hat{v}(\\pi_e)$', Js),\n", + " ('$\\hat{v}(\\pi_b)$', Gs),\n", + " ('OIS', OISs),\n", + " ('WIS', WISs),\n", + " ('C-OIS', FC_OISs_w),\n", + " ('C-WIS', FC_WISs_w),\n", + " ]:\n", + " df_bias_var.append([name, \n", + " np.mean(values), \n", + " np.mean(values - d0@np.sum(π_e*R, axis=1)), \n", + " np.sqrt(np.var(values)), \n", + " np.sqrt(np.mean(np.square(values - d0@np.sum(π_e*R, axis=1))))])\n", + " return pd.DataFrame(df_bias_var, columns=['Approach', 'Mean', 'Bias', 'Std', 'RMSE'])" + ] + }, + { + "cell_type": "code", + "execution_count": 6, + "id": "1b8f7a3a-4b24-44db-9c03-ac1c908a3241", + "metadata": {}, + "outputs": [], + "source": [ + "ww_constant = np.array([[\n", + " [0.5, 0.5],\n", + " [0.5, 0.5],\n", + "]])" + ] + }, + { + "cell_type": "code", + "execution_count": 7, + "id": "a1c19d27-de35-4a63-bb28-fc121703fa54", + "metadata": {}, + "outputs": [ + { + "data": { + "text/plain": [ + "[0.0, 0.1, 0.2, 0.3, 0.4, 0.5, 0.6, 0.7, 0.8, 0.9, 1.0]" + ] + }, + "execution_count": 7, + "metadata": {}, + "output_type": "execute_result" + } + ], + 
"source": [ + "pc_annot_list = list(np.arange(0,1+1e-10,0.1).round(2))\n", + "pc_annot_list" + ] + }, + { + "cell_type": "code", + "execution_count": 8, + "id": "0d77cfe8-c39e-4f7c-bb3b-5c06d11ce7a7", + "metadata": {}, + "outputs": [], + "source": [ + "R, sigma = np.array([[1., 2.],]), np.array([[1., 1.],])" + ] + }, + { + "cell_type": "code", + "execution_count": 9, + "id": "ebbe639a-b6dd-4f08-9b1f-8021a44d7bd4", + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "[[0.8 0.2]] [[0.1 0.9]]\n" + ] + } + ], + "source": [ + "π_b = np.array([[0.8, 0.2]])\n", + "π_e = np.array([[0.1, 0.9]])\n", + "print(π_b, π_e)" + ] + }, + { + "cell_type": "code", + "execution_count": 29, + "id": "0db0ccbc-2948-4833-af4e-89d6f14e107d", + "metadata": {}, + "outputs": [ + { + "name": "stderr", + "output_type": "stream", + "text": [ + "100%|██████████| 121/121 [09:59<00:00, 4.95s/it]\n" + ] + } + ], + "source": [ + "df_out_sweepPcannot_1 = []\n", + "for pc0, pc1 in tqdm(list(itertools.product(pc_annot_list, pc_annot_list))):\n", + " df_out = single_run_pcannot(np.array([[pc0, pc1]]), runs=50)\n", + " df_out_sweepPcannot_1.append(df_out)" + ] + }, + { + "cell_type": "code", + "execution_count": 53, + "id": "2e873ef0-c81a-4737-ba4d-58290564b9e1", + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "$R_0 \\sim N(1.0,1.0^2)$, $R_1 \\sim N(2.0, 1.0^2)$ \n", + " $\\pi_b=[0.8,0.2]$, $\\pi_e=[0.1,0.9]$\n" + ] + }, + { + "data": { + "image/svg+xml": [ + "\n", + "\n", + "\n", + "\n", + " \n", + " \n", + " \n", + " \n", + " 2023-04-26T10:58:22.657626\n", + " image/svg+xml\n", + " \n", + " \n", + " Matplotlib v3.3.4, https://matplotlib.org/\n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", 
+ " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + 
" \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + "\n" + ], + "text/plain": [ + "
" + ] + }, + "metadata": { + "needs_background": "light" + }, + "output_type": "display_data" + }, + { + "data": { + "image/svg+xml": [ + "\n", + "\n", + "\n", + "\n", + " \n", + " \n", + " \n", + " \n", + " 2023-04-26T10:58:22.933652\n", + " image/svg+xml\n", + " \n", + " \n", + " Matplotlib v3.3.4, https://matplotlib.org/\n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " 
\n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + "\n" + ], + "text/plain": [ + "
" + ] + }, + "metadata": { + "needs_background": "light" + }, + "output_type": "display_data" + }, + { + "data": { + "image/svg+xml": [ + "\n", + "\n", + "\n", + "\n", + " \n", + " \n", + " \n", + " \n", + " 2023-04-26T10:58:23.184669\n", + " image/svg+xml\n", + " \n", + " \n", + " Matplotlib v3.3.4, https://matplotlib.org/\n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " 
\n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + "\n" + ], + "text/plain": [ + "
" + ] + }, + "metadata": { + "needs_background": "light" + }, + "output_type": "display_data" + } + ], + "source": [ + "variances_grid = [df__.iloc[4,3] for df__ in df_out_sweepPcannot_1]\n", + "\n", + "fig, ax = plt.subplots(figsize=(3,3))\n", + "df_plot = pd.DataFrame(np.array(variances_grid).reshape((11,11)), index=pc_annot_list, columns=pc_annot_list)\n", + "df_plot.index.name = 'pc0'\n", + "df_plot.index = np.round(df_plot.index, 5)\n", + "df_plot.columns.name = 'pc1'\n", + "df_plot.columns = np.round(df_plot.columns, 5)\n", + "\n", + "# sns.heatmap(df_plot.T, cmap='crest', square=True, ax=ax,\n", + "# cbar_kws={'label': 'Var', \"shrink\": .82}, vmin=0, vmax=4)\n", + "sns.heatmap(np.log(df_plot.T*np.sqrt(N)), cmap='mako_r', square=True, ax=ax,\n", + " cbar_kws=dict(shrink=.82, aspect=40, pad=0.04), \n", + " vmin=-0.05, vmax=1.5,\n", + " )\n", + "# ax.collections[0].colorbar.set_label('$\\log(\\mathrm{Var})$', labelpad=-9, fontsize=8)\n", + "# ax.collections[0].colorbar.set_ticks([0, 0.25])\n", + "# ax.collections[0].colorbar.set_ticklabels(['$0$', r'$\\frac{1}{4}$'])\n", + "plt.xticks([0.5,10.5], ['$0$','$1$'], rotation=0)\n", + "plt.yticks([0.5,10.5], ['$0$','$1$'])\n", + "ax.yaxis.set_ticks_position('none')\n", + "ax.xaxis.set_ticks_position('none')\n", + "plt.xlabel('Fraction annotated\\n(action 0)', labelpad=-9, fontsize=9)\n", + "plt.ylabel('Fraction annotated\\n(action 1)', labelpad=-9, fontsize=9)\n", + "\n", + "for (i,w0), (j,w1) in itertools.product(enumerate(pc_annot_list), enumerate(pc_annot_list)):\n", + " if np.isnan(df_plot.loc[w0,w1]):\n", + " plt.plot(i+0.5,j+0.5, marker='x', c='r')\n", + " if np.isclose(df_plot.loc[w0,w1], np.nanmin(df_plot.values)):\n", + " plt.annotate('$*$', (i+0.5,j+0.5), c='k', ha='center', va='center')\n", + " # plt.plot(i+0.5,j+0.5, marker='.', c='yellow', ms=3)\n", + "\n", + "# plt.title('$R_0 \\sim N({},{}^2)$, $R_1 \\sim N({}, {}^2)$ \\n $\\pi_b=[{},{}]$, $\\pi_e=[{},{}]$'.format(\n", + "# R[0,0], sigma[0,0], R[0,1], 
sigma[0,1], π_b[0,0], π_b[0,1], π_e[0,0], π_e[0,1], \n", + "# ))\n", + "plt.gca().invert_yaxis()\n", + "ax.spines['top'].set_visible(True)\n", + "ax.spines['bottom'].set_visible(True)\n", + "ax.spines['right'].set_visible(True)\n", + "ax.spines['left'].set_visible(True)\n", + "\n", + "print('$R_0 \\sim N({},{}^2)$, $R_1 \\sim N({}, {}^2)$ \\n $\\pi_b=[{},{}]$, $\\pi_e=[{},{}]$'.format(\n", + " R[0,0], sigma[0,0], R[0,1], sigma[0,1], π_b[0,0], π_b[0,1], π_e[0,0], π_e[0,1], \n", + "))\n", + "plt.savefig('fig/bandit_logvar_b82_e19_Pcannot.pdf', bbox_inches='tight')\n", + "plt.show()\n", + "# display(df_plot)\n", + "\n", + "\n", + "biases_grid = [df__.iloc[4,2] for df__ in df_out_sweepPcannot_1]\n", + "\n", + "fig, ax = plt.subplots(figsize=(3,3))\n", + "df_plot = pd.DataFrame(np.array(biases_grid).reshape((11,11)), index=pc_annot_list, columns=pc_annot_list)\n", + "df_plot.index.name = 'w0'\n", + "df_plot.index = np.round(df_plot.index, 5)\n", + "df_plot.columns.name = 'w1'\n", + "df_plot.columns = np.round(df_plot.columns, 5)\n", + "\n", + "sns.heatmap(df_plot.T, cmap='mako_r', square=True, cbar_kws={'label': 'Bias', \"shrink\": .82}, \n", + " # vmin=0, vmax=2, \n", + " ax=ax)\n", + "plt.xticks([0.5,10.5], [0,1], rotation=0)\n", + "plt.yticks([0.5,10.5], [0,1])\n", + "ax.yaxis.set_ticks_position('none')\n", + "ax.xaxis.set_ticks_position('none')\n", + "# plt.xlabel('$W(\\\\tilde{a}=0|s,a=0)$', labelpad=-9)\n", + "# plt.ylabel('$W(\\\\tilde{a}=1|s,a=1)$', labelpad=-9)\n", + "\n", + "for (i,w0), (j,w1) in itertools.product(enumerate(pc_annot_list), enumerate(pc_annot_list)):\n", + " if np.isnan(df_plot.loc[w0,w1]):\n", + " plt.plot(i+0.5,j+0.5, marker='x', c='r')\n", + "\n", + "# plt.title('$R_0 \\sim N({},{}^2)$, $R_1 \\sim N({}, {}^2)$ \\n $\\pi_b=[{},{}]$, $\\pi_e=[{},{}]$'.format(\n", + "# R[0,0], sigma[0,0], R[0,1], sigma[0,1], π_b[0,0], π_b[0,1], π_e[0,0], π_e[0,1], \n", + "# ))\n", + "plt.gca().invert_yaxis()\n", + "ax.spines['top'].set_visible(True)\n", + 
"ax.spines['bottom'].set_visible(True)\n", + "ax.spines['right'].set_visible(True)\n", + "ax.spines['left'].set_visible(True)\n", + "\n", + "plt.show()\n", + "\n", + "\n", + "mses_grid = [df__.iloc[4,4] for df__ in df_out_sweepPcannot_1]\n", + "\n", + "fig, ax = plt.subplots(figsize=(3,3))\n", + "df_plot = pd.DataFrame(np.array(mses_grid).reshape((11,11)), index=pc_annot_list, columns=pc_annot_list)\n", + "df_plot.index.name = 'w0'\n", + "df_plot.index = np.round(df_plot.index, 5)\n", + "df_plot.columns.name = 'w1'\n", + "df_plot.columns = np.round(df_plot.columns, 5)\n", + "\n", + "sns.heatmap(df_plot.T, cmap='mako_r', square=True, cbar_kws={'label': 'MSE', \"shrink\": .82}, \n", + " vmin=0, vmax=0.35, \n", + " ax=ax)\n", + "plt.xticks([0.5,10.5], [0,1], rotation=0)\n", + "plt.yticks([0.5,10.5], [0,1])\n", + "ax.yaxis.set_ticks_position('none')\n", + "ax.xaxis.set_ticks_position('none')\n", + "# plt.xlabel('$W(\\\\tilde{a}=0|s,a=0)$', labelpad=-9)\n", + "# plt.ylabel('$W(\\\\tilde{a}=1|s,a=1)$', labelpad=-9)\n", + "\n", + "for (i,w0), (j,w1) in itertools.product(enumerate(pc_annot_list), enumerate(pc_annot_list)):\n", + " if np.isnan(df_plot.loc[w0,w1]):\n", + " plt.plot(i+0.5,j+0.5, marker='x', c='r')\n", + "\n", + "# plt.title('$R_0 \\sim N({},{}^2)$, $R_1 \\sim N({}, {}^2)$ \\n $\\pi_b=[{},{}]$, $\\pi_e=[{},{}]$'.format(\n", + "# R[0,0], sigma[0,0], R[0,1], sigma[0,1], π_b[0,0], π_b[0,1], π_e[0,0], π_e[0,1], \n", + "# ))\n", + "plt.gca().invert_yaxis()\n", + "ax.spines['top'].set_visible(True)\n", + "ax.spines['bottom'].set_visible(True)\n", + "ax.spines['right'].set_visible(True)\n", + "ax.spines['left'].set_visible(True)\n", + "\n", + "plt.show()" + ] + }, + { + "cell_type": "code", + "execution_count": 35, + "id": "a58566f1-6e85-440a-8e1a-40d946f7f3d7", + "metadata": {}, + "outputs": [ + { + "data": { + "text/html": [ + "
\n", + "\n", + "\n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + "
w10.00.10.20.30.40.50.60.70.80.91.0
w0
0.00.3012160.3112620.3123840.3058660.3028400.2902370.2904930.2907940.2976900.3043010.303111
0.10.2475120.2500260.2462460.2413970.2401240.2272680.2230540.2229310.2277800.2349470.231729
0.20.1982640.2030450.1993420.1914900.1909380.1824600.1832270.1781800.1846670.1854450.177459
0.30.1598400.1616460.1594250.1533760.1535690.1490860.1495330.1471110.1546030.1550650.151364
0.40.1480410.1465280.1486260.1438780.1432970.1401530.1403710.1364290.1448880.1415650.137793
0.50.1382470.1360410.1370940.1332100.1332470.1301710.1300300.1262170.1315960.1251400.122939
0.60.1230960.1200140.1200500.1170650.1178560.1140410.1124760.1104360.1177930.1125390.109014
0.70.0992190.0953930.0954560.0917220.0904240.0864600.0840880.0832640.0919860.0846970.082471
0.80.0947440.0926720.0918720.0880090.0865280.0842890.0822060.0813800.0869090.0809750.078607
0.90.0833330.0787430.0773400.0745650.0737220.0713380.0693120.0710370.0764560.0708290.070532
1.00.0775180.0741600.0728480.0696820.0677930.0683060.0658640.0666910.0705230.0672900.068661
\n", + "
" + ], + "text/plain": [ + "w1 0.0 0.1 0.2 0.3 0.4 0.5 0.6 \\\n", + "w0 \n", + "0.0 0.301216 0.311262 0.312384 0.305866 0.302840 0.290237 0.290493 \n", + "0.1 0.247512 0.250026 0.246246 0.241397 0.240124 0.227268 0.223054 \n", + "0.2 0.198264 0.203045 0.199342 0.191490 0.190938 0.182460 0.183227 \n", + "0.3 0.159840 0.161646 0.159425 0.153376 0.153569 0.149086 0.149533 \n", + "0.4 0.148041 0.146528 0.148626 0.143878 0.143297 0.140153 0.140371 \n", + "0.5 0.138247 0.136041 0.137094 0.133210 0.133247 0.130171 0.130030 \n", + "0.6 0.123096 0.120014 0.120050 0.117065 0.117856 0.114041 0.112476 \n", + "0.7 0.099219 0.095393 0.095456 0.091722 0.090424 0.086460 0.084088 \n", + "0.8 0.094744 0.092672 0.091872 0.088009 0.086528 0.084289 0.082206 \n", + "0.9 0.083333 0.078743 0.077340 0.074565 0.073722 0.071338 0.069312 \n", + "1.0 0.077518 0.074160 0.072848 0.069682 0.067793 0.068306 0.065864 \n", + "\n", + "w1 0.7 0.8 0.9 1.0 \n", + "w0 \n", + "0.0 0.290794 0.297690 0.304301 0.303111 \n", + "0.1 0.222931 0.227780 0.234947 0.231729 \n", + "0.2 0.178180 0.184667 0.185445 0.177459 \n", + "0.3 0.147111 0.154603 0.155065 0.151364 \n", + "0.4 0.136429 0.144888 0.141565 0.137793 \n", + "0.5 0.126217 0.131596 0.125140 0.122939 \n", + "0.6 0.110436 0.117793 0.112539 0.109014 \n", + "0.7 0.083264 0.091986 0.084697 0.082471 \n", + "0.8 0.081380 0.086909 0.080975 0.078607 \n", + "0.9 0.071037 0.076456 0.070829 0.070532 \n", + "1.0 0.066691 0.070523 0.067290 0.068661 " + ] + }, + "execution_count": 35, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "df_plot" + ] + }, + { + "cell_type": "code", + "execution_count": 36, + "id": "55147140-293e-4491-9d0e-814de44c4565", + "metadata": {}, + "outputs": [ + { + "data": { + "text/plain": [ + "w1\n", + "0.0 0.151912\n", + "0.1 0.151775\n", + "0.2 0.150971\n", + "0.3 0.146387\n", + "0.4 0.145485\n", + "0.5 0.140346\n", + "0.6 0.139150\n", + "0.7 0.137679\n", + "0.8 0.144081\n", + "0.9 0.142072\n", + "1.0 0.139425\n", 
+ "dtype: float64" + ] + }, + "execution_count": 36, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "df_plot.mean()" + ] + }, + { + "cell_type": "code", + "execution_count": 17, + "id": "bd2be05d-f895-479c-b3b8-e2aded1f0fff", + "metadata": {}, + "outputs": [ + { + "data": { + "text/html": [ + "
\n", + "\n", + "\n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + "
ApproachMeanBiasStdRMSE
0$\\hat{v}(\\pi_e)$1.860696-0.0393040.1130680.119705
1$\\hat{v}(\\pi_b)$1.162696-0.7373040.1138930.746049
2OIS1.780833-0.1191670.2232680.253080
3WIS1.860249-0.0397510.2149850.218629
4C-OIS1.780833-0.1191670.2232680.253080
5C-WIS1.860249-0.0397510.2149850.218629
\n", + "
" + ], + "text/plain": [ + " Approach Mean Bias Std RMSE\n", + "0 $\\hat{v}(\\pi_e)$ 1.860696 -0.039304 0.113068 0.119705\n", + "1 $\\hat{v}(\\pi_b)$ 1.162696 -0.737304 0.113893 0.746049\n", + "2 OIS 1.780833 -0.119167 0.223268 0.253080\n", + "3 WIS 1.860249 -0.039751 0.214985 0.218629\n", + "4 C-OIS 1.780833 -0.119167 0.223268 0.253080\n", + "5 C-WIS 1.860249 -0.039751 0.214985 0.218629" + ] + }, + "execution_count": 17, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "single_run_pcannot(np.array([[0,0]]), runs=10)" + ] + }, + { + "cell_type": "code", + "execution_count": 18, + "id": "e811d17b-9e4a-4250-8acb-392f70392428", + "metadata": {}, + "outputs": [ + { + "data": { + "text/html": [ + "
\n", + "\n", + "\n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + "
ApproachMeanBiasStdRMSE
0$\\hat{v}(\\pi_e)$1.860696-0.0393040.1130680.119705
1$\\hat{v}(\\pi_b)$1.162696-0.7373040.1138930.746049
2OIS1.780833-0.1191670.2232680.253080
3WIS1.860249-0.0397510.2149850.218629
4C-OIS1.888505-0.0114950.0699310.070870
5C-WIS1.888505-0.0114950.0699310.070870
\n", + "
" + ], + "text/plain": [ + " Approach Mean Bias Std RMSE\n", + "0 $\\hat{v}(\\pi_e)$ 1.860696 -0.039304 0.113068 0.119705\n", + "1 $\\hat{v}(\\pi_b)$ 1.162696 -0.737304 0.113893 0.746049\n", + "2 OIS 1.780833 -0.119167 0.223268 0.253080\n", + "3 WIS 1.860249 -0.039751 0.214985 0.218629\n", + "4 C-OIS 1.888505 -0.011495 0.069931 0.070870\n", + "5 C-WIS 1.888505 -0.011495 0.069931 0.070870" + ] + }, + "execution_count": 18, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "single_run_pcannot(np.array([[1,1]]), runs=10)" + ] + }, + { + "cell_type": "code", + "execution_count": 19, + "id": "5bacadd6-7618-4df9-9335-b32086a91b3c", + "metadata": {}, + "outputs": [ + { + "data": { + "text/html": [ + "
\n", + "\n", + "\n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + "
ApproachMeanBiasStdRMSE
0$\\hat{v}(\\pi_e)$1.860696-0.0393040.1130680.119705
1$\\hat{v}(\\pi_b)$1.162696-0.7373040.1138930.746049
2OIS1.780833-0.1191670.2232680.253080
3WIS1.860249-0.0397510.2149850.218629
4C-OIS1.874901-0.0250990.1459730.148116
5C-WIS1.883400-0.0166000.1121950.113416
\n", + "
" + ], + "text/plain": [ + " Approach Mean Bias Std RMSE\n", + "0 $\\hat{v}(\\pi_e)$ 1.860696 -0.039304 0.113068 0.119705\n", + "1 $\\hat{v}(\\pi_b)$ 1.162696 -0.737304 0.113893 0.746049\n", + "2 OIS 1.780833 -0.119167 0.223268 0.253080\n", + "3 WIS 1.860249 -0.039751 0.214985 0.218629\n", + "4 C-OIS 1.874901 -0.025099 0.145973 0.148116\n", + "5 C-WIS 1.883400 -0.016600 0.112195 0.113416" + ] + }, + "execution_count": 19, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "single_run_pcannot(np.array([[0.5, 0.5]]), runs=10)" + ] + }, + { + "cell_type": "markdown", + "id": "49049a9e-c99d-40fc-aac7-cb4581c56d41", + "metadata": {}, + "source": [ + "# %Missing counterfactual annotations, constant mean split weights, impute annotations" + ] + }, + { + "cell_type": "code", + "execution_count": 88, + "id": "3188e7be-f29f-4eb9-aa74-7c01b0e5f8f2", + "metadata": {}, + "outputs": [], + "source": [ + "def single_run_pcannot_impute(pc_annots, runs=1000, annot_std_scale=1.0):\n", + " np.random.seed(42)\n", + "\n", + " ## probability of getting a counterfactual annotation\n", + " Pc_local = pc_annots\n", + " \n", + " ww = ww_constant\n", + " \n", + " # True value of π_e\n", + " Js = []\n", + " for seed in range(runs):\n", + " rng = np.random.default_rng(seed=10+seed)\n", + " x = rng.choice(1, size=N, p=d0)\n", + " a = np.array([rng.choice(2, p=π_e[xi]) for xi in x])\n", + " r = np.array([rng.normal(R[xi,ai], sigma[xi,ai]) for xi,ai in zip(x,a)])\n", + " J = np.sum(r) / N\n", + " Js.append(J)\n", + "\n", + " # Standard IS\n", + " Gs = []\n", + " OISs = []\n", + " WISs = []\n", + " for seed in range(runs):\n", + " rng = np.random.default_rng(seed=10+seed)\n", + " x = rng.choice(1, size=N, p=d0)\n", + " a = np.array([rng.choice(2, p=π_b[xi]) for xi in x])\n", + " r = np.array([rng.normal(R[xi,ai], sigma[xi,ai]) for xi,ai in zip(x,a)])\n", + " G = np.sum(r) / N\n", + " Gs.append(G)\n", + "\n", + " if use_πD:\n", + " π_b_ = np.array([\n", + " 
[(np.sum((x==0)&(a==0)))/np.sum(x==0), \n", + " (np.sum((x==0)&(a==1)))/np.sum(x==0)],\n", + " ])\n", + " else:\n", + " π_b_ = π_b\n", + "\n", + " rho = π_e[x,a] / π_b_[x,a]\n", + " OISs.append(np.sum(rho * r) / N)\n", + " WISs.append(np.sum(rho * r) / np.sum(rho))\n", + "\n", + "\n", + " # Collect data using π_b - combining counterfactuals with factuals\n", + " FC_OISs_w = []\n", + " FC_WISs_w = []\n", + " for seed in range(runs):\n", + " rng = np.random.default_rng(seed=10+seed)\n", + " rng_c = np.random.default_rng(seed=100000+seed)\n", + " x = rng.choice(1, size=N, p=d0)\n", + " a = np.array([rng.choice(2, p=π_b[xi]) for xi in x])\n", + " r = np.array([rng.normal(R[xi,ai], sigma[xi,ai]) for xi,ai in zip(x,a)])\n", + " rho = π_e[x,a] / π_b[x,a]\n", + "\n", + " # counterfactual flag\n", + " c = np.array([rng_c.choice(2, p=[1-Pc_local[xi,ai], Pc_local[xi,ai]]) for xi,ai in zip(x,a)])\n", + " # print(c)\n", + "\n", + " # counterfactual reward\n", + " rc = np.array([rng_c.normal(R[xi,1-ai], annot_std_scale*sigma[xi,1-ai]) for xi,ai in zip(x,a)])\n", + " rc[c==0] = np.nan\n", + "\n", + " # impute missing counterfactuals\n", + " rc_impute = np.full_like(rc, np.nan)\n", + " c_impute = np.full_like(c, 0)\n", + " for i, (ci,xi,ai) in enumerate(zip(c,x,a)):\n", + " if ci == 1:\n", + " pass\n", + " else:\n", + " rc_src = rc[(c==1) & (x==xi) & (a==ai)]\n", + " if len(rc_src) > 0:\n", + " rc_impute[i] = np.mean(rc_src)\n", + " c_impute[i] = 1\n", + " \n", + " rc[c==0] = rc_impute[c==0]\n", + " c[c==0] = c_impute[c==0]\n", + " \n", + " # trajectory-wise weight\n", + " w = np.ones(N)\n", + " w[c==1] = ww[x[c==1], a[c==1], a[c==1]]\n", + " wc = np.zeros(N)\n", + " wc[c==1] = ww[x[c==1], a[c==1], 1-a[c==1]]\n", + " \n", + " # print('a', a, 'r', r, 'c', c, 'rc', rc, 'w', w, 'wc', wc)\n", + " # print(np.sum(w), np.sum(wc))\n", + " \n", + " if use_πD:\n", + " # augmented behavior policy\n", + " π_b_ = np.array([\n", + " 
[(np.sum(w*((x==0)&(a==0)))+np.sum(wc*((x==0)&(a==1)&(c==1))))/(np.sum(w*(x==0))+np.sum(wc*((x==0)&(c==1)))), \n", + " (np.sum(w*((x==0)&(a==1)))+np.sum(wc*((x==0)&(a==0)&(c==1))))/(np.sum(w*(x==0))+np.sum(wc*((x==0)&(c==1))))],\n", + " ])\n", + " else:\n", + " # augmented behavior policy\n", + " Pc_round = (Pc_local > 0).astype(int)\n", + " π_b_ = np.array([\n", + " [(1-Pc_round[0,0])*π_b[0,0] + Pc_round[0,0]*π_b[0,0]*ww[0,0,0] + Pc_round[0,1]*π_b[0,1]*ww[0,1,0], \n", + " (1-Pc_round[0,1])*π_b[0,1] + Pc_round[0,0]*π_b[0,0]*ww[0,0,1] + Pc_round[0,1]*π_b[0,1]*ww[0,1,1]],\n", + " ])\n", + " π_b_ = π_b_ / π_b_.sum(axis=1, keepdims=True)\n", + " # print(Pc_local)\n", + " # print(Pc_round)\n", + " # print(π_b_)\n", + " \n", + " FC_OISs_w.append(\n", + " (np.sum(w* π_e[x,a] / π_b_[x,a] * r) + np.nansum(wc* π_e[x,1-a] / π_b_[x,1-a] * rc)) / (N)\n", + " )\n", + " FC_WISs_w.append(\n", + " (np.sum(w* π_e[x,a] / π_b_[x,a] * r) + np.nansum(wc* π_e[x,1-a] / π_b_[x,1-a] * rc)) / (np.sum(w* π_e[x,a] / π_b_[x,a]) + np.sum((wc* π_e[x,1-a] / π_b_[x,1-a])[c==1])),\n", + " )\n", + "\n", + " df_bias_var = []\n", + " for name, values in [\n", + " ('$\\hat{v}(\\pi_e)$', Js),\n", + " ('$\\hat{v}(\\pi_b)$', Gs),\n", + " ('OIS', OISs),\n", + " ('WIS', WISs),\n", + " ('C-OIS', FC_OISs_w),\n", + " ('C-WIS', FC_WISs_w),\n", + " ]:\n", + " df_bias_var.append([name, \n", + " np.mean(values), \n", + " np.mean(values - d0@np.sum(π_e*R, axis=1)), \n", + " np.sqrt(np.var(values)), \n", + " np.sqrt(np.mean(np.square(values - d0@np.sum(π_e*R, axis=1))))])\n", + " return pd.DataFrame(df_bias_var, columns=['Approach', 'Mean', 'Bias', 'Std', 'RMSE'])" + ] + }, + { + "cell_type": "code", + "execution_count": 92, + "id": "09d9ee9a-9803-468f-817b-78bb6208c97f", + "metadata": {}, + "outputs": [ + { + "data": { + "text/html": [ + "
\n", + "\n", + "\n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + "
ApproachMeanBiasStdRMSE
0$\\hat{v}(\\pi_e)$1.9064690.0064690.1177660.117944
1$\\hat{v}(\\pi_b)$1.208469-0.6915310.1321540.704045
2OIS1.831194-0.0688060.2151240.225860
3WIS1.9480240.0480240.2063360.211851
4C-OIS2.5225730.6225730.9262321.116021
5C-WIS1.9765110.0765110.2334380.245657
\n", + "
" + ], + "text/plain": [ + " Approach Mean Bias Std RMSE\n", + "0 $\\hat{v}(\\pi_e)$ 1.906469 0.006469 0.117766 0.117944\n", + "1 $\\hat{v}(\\pi_b)$ 1.208469 -0.691531 0.132154 0.704045\n", + "2 OIS 1.831194 -0.068806 0.215124 0.225860\n", + "3 WIS 1.948024 0.048024 0.206336 0.211851\n", + "4 C-OIS 2.522573 0.622573 0.926232 1.116021\n", + "5 C-WIS 1.976511 0.076511 0.233438 0.245657" + ] + }, + "execution_count": 92, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "single_run_pcannot_impute(np.array([[0., 0.1]]), runs=5)" + ] + }, + { + "cell_type": "code", + "execution_count": 94, + "id": "b2951b82-96e5-428e-ac87-ac50f161aae1", + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "[[0.81 0.19]]\n", + "[[0.81 0.19]]\n", + "[[0.81 0.19]]\n", + "[[0.81 0.19]]\n", + "[[0.81 0.19]]\n" + ] + }, + { + "data": { + "text/html": [ + "
\n", + "\n", + "\n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + "
ApproachMeanBiasStdRMSE
0$\\hat{v}(\\pi_e)$1.874441-0.0255590.0178530.031177
1$\\hat{v}(\\pi_b)$1.184441-0.7155590.0352420.716426
2OIS1.881578-0.0184220.1774020.178356
3WIS1.888907-0.0110930.0765110.077311
4C-OIS1.864218-0.0357820.1851210.188548
5C-WIS1.886975-0.0130250.0770640.078157
\n", + "
" + ], + "text/plain": [ + " Approach Mean Bias Std RMSE\n", + "0 $\\hat{v}(\\pi_e)$ 1.874441 -0.025559 0.017853 0.031177\n", + "1 $\\hat{v}(\\pi_b)$ 1.184441 -0.715559 0.035242 0.716426\n", + "2 OIS 1.881578 -0.018422 0.177402 0.178356\n", + "3 WIS 1.888907 -0.011093 0.076511 0.077311\n", + "4 C-OIS 1.864218 -0.035782 0.185121 0.188548\n", + "5 C-WIS 1.886975 -0.013025 0.077064 0.078157" + ] + }, + "execution_count": 94, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "single_run_pcannot(np.array([[0., 0.1]]), runs=5)" + ] + }, + { + "cell_type": "code", + "execution_count": 28, + "id": "669a74c3-eae8-4b39-a159-5a9906313b41", + "metadata": {}, + "outputs": [], + "source": [ + "ww_constant = np.array([[\n", + " [0.5, 0.5],\n", + " [0.5, 0.5],\n", + "]])" + ] + }, + { + "cell_type": "code", + "execution_count": 29, + "id": "84b249de-7021-4f46-8a54-327ccd3b375e", + "metadata": {}, + "outputs": [ + { + "data": { + "text/plain": [ + "[0.0, 0.1, 0.2, 0.3, 0.4, 0.5, 0.6, 0.7, 0.8, 0.9, 1.0]" + ] + }, + "execution_count": 29, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "pc_annot_list = list(np.arange(0,1+1e-10,0.1).round(2))\n", + "pc_annot_list" + ] + }, + { + "cell_type": "code", + "execution_count": 30, + "id": "ba65a065-cc8a-43b5-80ce-f5418e96231e", + "metadata": {}, + "outputs": [], + "source": [ + "R, sigma = np.array([[1., 2.],]), np.array([[1., 1.],])" + ] + }, + { + "cell_type": "code", + "execution_count": 31, + "id": "f1f8ceff-9579-42d3-b5f0-4d3d7a68bd2b", + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "[[0.8 0.2]] [[0.1 0.9]]\n" + ] + } + ], + "source": [ + "π_b = np.array([[0.8, 0.2]])\n", + "π_e = np.array([[0.1, 0.9]])\n", + "print(π_b, π_e)" + ] + }, + { + "cell_type": "code", + "execution_count": 104, + "id": "6559e365-8bce-44d8-9f40-dfa5a111ee45", + "metadata": {}, + "outputs": [ + { + "name": "stderr", + "output_type": "stream", + "text": [ + 
"100%|██████████| 121/121 [10:58<00:00, 5.44s/it]\n" + ] + } + ], + "source": [ + "df_out_sweepPcannot_2 = []\n", + "for pc0, pc1 in tqdm(list(itertools.product(pc_annot_list, pc_annot_list))):\n", + " df_out = single_run_pcannot_impute(np.array([[pc0, pc1]]), runs=50)\n", + " df_out_sweepPcannot_2.append(df_out)" + ] + }, + { + "cell_type": "code", + "execution_count": 114, + "id": "a5b23916-f31d-497c-8934-7df945715f66", + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "$R_0 \\sim N(1.0,1.0^2)$, $R_1 \\sim N(2.0, 1.0^2)$ \n", + " $\\pi_b=[0.8,0.2]$, $\\pi_e=[0.1,0.9]$\n" + ] + }, + { + "data": { + "image/svg+xml": [ + "\n", + "\n", + "\n", + "\n", + " \n", + " \n", + " \n", + " \n", + " 2023-04-26T10:58:01.654952\n", + " image/svg+xml\n", + " \n", + " \n", + " Matplotlib v3.3.4, https://matplotlib.org/\n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " 
\n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " 
\n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + "\n" + ], + "text/plain": [ + "
" + ] + }, + "metadata": { + "needs_background": "light" + }, + "output_type": "display_data" + }, + { + "data": { + "image/svg+xml": [ + "\n", + "\n", + "\n", + "\n", + " \n", + " \n", + " \n", + " \n", + " 2023-04-26T10:58:01.945824\n", + " image/svg+xml\n", + " \n", + " \n", + " Matplotlib v3.3.4, https://matplotlib.org/\n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " 
\n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + "\n" + ], + "text/plain": [ + "
" + ] + }, + "metadata": { + "needs_background": "light" + }, + "output_type": "display_data" + }, + { + "data": { + "image/svg+xml": [ + "\n", + "\n", + "\n", + "\n", + " \n", + " \n", + " \n", + " \n", + " 2023-04-26T10:58:02.196174\n", + " image/svg+xml\n", + " \n", + " \n", + " Matplotlib v3.3.4, https://matplotlib.org/\n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " 
\n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + "\n" + ], + "text/plain": [ + "
" + ] + }, + "metadata": { + "needs_background": "light" + }, + "output_type": "display_data" + } + ], + "source": [ + "variances_grid = [df__.iloc[4,3] for df__ in df_out_sweepPcannot_2]\n", + "\n", + "fig, ax = plt.subplots(figsize=(3,3))\n", + "df_plot = pd.DataFrame(np.array(variances_grid).reshape((11,11)), index=pc_annot_list, columns=pc_annot_list)\n", + "df_plot.index.name = 'w0'\n", + "df_plot.index = np.round(df_plot.index, 5)\n", + "df_plot.columns.name = 'w1'\n", + "df_plot.columns = np.round(df_plot.columns, 5)\n", + "\n", + "# sns.heatmap(df_plot.T, cmap='crest', square=True, ax=ax,\n", + "# cbar_kws={'label': 'Var', \"shrink\": .82}, vmin=0, vmax=4)\n", + "sns.heatmap(np.log(df_plot.T*np.sqrt(N)), cmap='mako_r', square=True, ax=ax,\n", + " cbar_kws=dict(shrink=.82, aspect=40, pad=0.04), \n", + " vmin=-0.05, vmax=1.5,\n", + " )\n", + "# ax.collections[0].colorbar.set_label('$\\log(\\mathrm{Var})$', labelpad=-9, fontsize=8)\n", + "# ax.collections[0].colorbar.set_ticks([0, 0.25])\n", + "# ax.collections[0].colorbar.set_ticklabels(['$0$', r'$\\frac{1}{4}$'])\n", + "plt.xticks([0.5,10.5], ['$0$','$1$'], rotation=0)\n", + "plt.yticks([0.5,10.5], ['$0$','$1$'])\n", + "ax.yaxis.set_ticks_position('none')\n", + "ax.xaxis.set_ticks_position('none')\n", + "plt.xlabel('Fraction annotated\\n(action 0)', labelpad=-9, fontsize=9)\n", + "plt.ylabel('Fraction annotated\\n(action 1)', labelpad=-9, fontsize=9)\n", + "\n", + "for (i,w0), (j,w1) in itertools.product(enumerate(pc_annot_list), enumerate(pc_annot_list)):\n", + " if np.isnan(df_plot.loc[w0,w1]):\n", + " plt.plot(i+0.5,j+0.5, marker='x', c='r')\n", + " if np.isclose(df_plot.loc[w0,w1], np.nanmin(df_plot.values)):\n", + " plt.annotate('$*$', (i+0.5,j+0.5), c='k', ha='center', va='center')\n", + " # plt.plot(i+0.5,j+0.5, marker='.', c='yellow', ms=3)\n", + "\n", + "# plt.title('$R_0 \\sim N({},{}^2)$, $R_1 \\sim N({}, {}^2)$ \\n $\\pi_b=[{},{}]$, $\\pi_e=[{},{}]$'.format(\n", + "# R[0,0], sigma[0,0], R[0,1], 
sigma[0,1], π_b[0,0], π_b[0,1], π_e[0,0], π_e[0,1], \n", + "# ))\n", + "plt.gca().invert_yaxis()\n", + "ax.spines['top'].set_visible(True)\n", + "ax.spines['bottom'].set_visible(True)\n", + "ax.spines['right'].set_visible(True)\n", + "ax.spines['left'].set_visible(True)\n", + "\n", + "print('$R_0 \\sim N({},{}^2)$, $R_1 \\sim N({}, {}^2)$ \\n $\\pi_b=[{},{}]$, $\\pi_e=[{},{}]$'.format(\n", + " R[0,0], sigma[0,0], R[0,1], sigma[0,1], π_b[0,0], π_b[0,1], π_e[0,0], π_e[0,1], \n", + "))\n", + "plt.savefig('fig/bandit_logvar_b82_e19_Pcannot_impute.pdf', bbox_inches='tight')\n", + "plt.show()\n", + "# display(df_plot)\n", + "\n", + "\n", + "biases_grid = [df__.iloc[4,2] for df__ in df_out_sweepPcannot_2]\n", + "\n", + "fig, ax = plt.subplots(figsize=(3,3))\n", + "df_plot = pd.DataFrame(np.array(biases_grid).reshape((11,11)), index=pc_annot_list, columns=pc_annot_list)\n", + "df_plot.index.name = 'w0'\n", + "df_plot.index = np.round(df_plot.index, 5)\n", + "df_plot.columns.name = 'w1'\n", + "df_plot.columns = np.round(df_plot.columns, 5)\n", + "\n", + "sns.heatmap(df_plot.T, cmap='mako_r', square=True, cbar_kws={'label': 'Bias', \"shrink\": .82}, \n", + " # vmin=0, vmax=2, \n", + " ax=ax)\n", + "plt.xticks([0.5,10.5], [0,1], rotation=0)\n", + "plt.yticks([0.5,10.5], [0,1])\n", + "ax.yaxis.set_ticks_position('none')\n", + "ax.xaxis.set_ticks_position('none')\n", + "# plt.xlabel('$W(\\\\tilde{a}=0|s,a=0)$', labelpad=-9)\n", + "# plt.ylabel('$W(\\\\tilde{a}=1|s,a=1)$', labelpad=-9)\n", + "\n", + "for (i,w0), (j,w1) in itertools.product(enumerate(pc_annot_list), enumerate(pc_annot_list)):\n", + " if np.isnan(df_plot.loc[w0,w1]):\n", + " plt.plot(i+0.5,j+0.5, marker='x', c='r')\n", + "\n", + "# plt.title('$R_0 \\sim N({},{}^2)$, $R_1 \\sim N({}, {}^2)$ \\n $\\pi_b=[{},{}]$, $\\pi_e=[{},{}]$'.format(\n", + "# R[0,0], sigma[0,0], R[0,1], sigma[0,1], π_b[0,0], π_b[0,1], π_e[0,0], π_e[0,1], \n", + "# ))\n", + "plt.gca().invert_yaxis()\n", + "ax.spines['top'].set_visible(True)\n", 
+ "ax.spines['bottom'].set_visible(True)\n", + "ax.spines['right'].set_visible(True)\n", + "ax.spines['left'].set_visible(True)\n", + "\n", + "plt.show()\n", + "\n", + "\n", + "mses_grid = [df__.iloc[4,4] for df__ in df_out_sweepPcannot_2]\n", + "\n", + "fig, ax = plt.subplots(figsize=(3,3))\n", + "df_plot = pd.DataFrame(np.array(mses_grid).reshape((11,11)), index=pc_annot_list, columns=pc_annot_list)\n", + "df_plot.index.name = 'w0'\n", + "df_plot.index = np.round(df_plot.index, 5)\n", + "df_plot.columns.name = 'w1'\n", + "df_plot.columns = np.round(df_plot.columns, 5)\n", + "\n", + "sns.heatmap(df_plot.T, cmap='mako_r', square=True, cbar_kws={'label': 'MSE', \"shrink\": .82}, \n", + " vmin=0, vmax=0.35, \n", + " ax=ax)\n", + "plt.xticks([0.5,10.5], [0,1], rotation=0)\n", + "plt.yticks([0.5,10.5], [0,1])\n", + "ax.yaxis.set_ticks_position('none')\n", + "ax.xaxis.set_ticks_position('none')\n", + "# plt.xlabel('$W(\\\\tilde{a}=0|s,a=0)$', labelpad=-9)\n", + "# plt.ylabel('$W(\\\\tilde{a}=1|s,a=1)$', labelpad=-9)\n", + "\n", + "for (i,w0), (j,w1) in itertools.product(enumerate(pc_annot_list), enumerate(pc_annot_list)):\n", + " if np.isnan(df_plot.loc[w0,w1]):\n", + " plt.plot(i+0.5,j+0.5, marker='x', c='r')\n", + "\n", + "# plt.title('$R_0 \\sim N({},{}^2)$, $R_1 \\sim N({}, {}^2)$ \\n $\\pi_b=[{},{}]$, $\\pi_e=[{},{}]$'.format(\n", + "# R[0,0], sigma[0,0], R[0,1], sigma[0,1], π_b[0,0], π_b[0,1], π_e[0,0], π_e[0,1], \n", + "# ))\n", + "plt.gca().invert_yaxis()\n", + "ax.spines['top'].set_visible(True)\n", + "ax.spines['bottom'].set_visible(True)\n", + "ax.spines['right'].set_visible(True)\n", + "ax.spines['left'].set_visible(True)\n", + "\n", + "plt.show()" + ] + }, + { + "cell_type": "code", + "execution_count": 107, + "id": "c6234209-0964-4790-b3d7-5ba890a39314", + "metadata": {}, + "outputs": [ + { + "data": { + "text/html": [ + "
\n", + "\n", + "\n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + "
w10.00.10.20.30.40.50.60.70.80.91.0
w0
0.00.3012160.3146930.3043060.3037770.3033710.3032860.3030460.3032200.3030560.3030980.303111
0.10.1366640.1589210.1535280.1531110.1531530.1537000.1535080.1535470.1538550.1538550.153808
0.20.1057210.1310820.1216860.1215140.1214780.1219780.1217120.1216160.1218970.1218750.121903
0.30.0941480.1197360.1092560.1092260.1089360.1090900.1089250.1087300.1089830.1090840.109028
0.40.0860740.1036180.0982150.0978160.0972240.0972670.0971200.0967920.0970160.0971340.097033
0.50.0805110.0890010.0836480.0833540.0830490.0832230.0828600.0825580.0829730.0830600.083053
0.60.0776880.0801730.0753910.0749650.0749440.0750160.0745350.0742340.0746270.0747000.074654
0.70.0760640.0763650.0731440.0727260.0728380.0729530.0724940.0720930.0724970.0726020.072725
0.80.0770880.0771330.0715390.0710500.0712330.0713910.0710490.0707110.0711310.0712680.071324
0.90.0765020.0746440.0703750.0697890.0698620.0701780.0698690.0695730.0699810.0700630.070120
1.00.0775180.0739180.0691170.0686510.0686880.0688160.0684870.0681970.0685660.0686250.068661
\n", + "
" + ], + "text/plain": [ + "w1 0.0 0.1 0.2 0.3 0.4 0.5 0.6 \\\n", + "w0 \n", + "0.0 0.301216 0.314693 0.304306 0.303777 0.303371 0.303286 0.303046 \n", + "0.1 0.136664 0.158921 0.153528 0.153111 0.153153 0.153700 0.153508 \n", + "0.2 0.105721 0.131082 0.121686 0.121514 0.121478 0.121978 0.121712 \n", + "0.3 0.094148 0.119736 0.109256 0.109226 0.108936 0.109090 0.108925 \n", + "0.4 0.086074 0.103618 0.098215 0.097816 0.097224 0.097267 0.097120 \n", + "0.5 0.080511 0.089001 0.083648 0.083354 0.083049 0.083223 0.082860 \n", + "0.6 0.077688 0.080173 0.075391 0.074965 0.074944 0.075016 0.074535 \n", + "0.7 0.076064 0.076365 0.073144 0.072726 0.072838 0.072953 0.072494 \n", + "0.8 0.077088 0.077133 0.071539 0.071050 0.071233 0.071391 0.071049 \n", + "0.9 0.076502 0.074644 0.070375 0.069789 0.069862 0.070178 0.069869 \n", + "1.0 0.077518 0.073918 0.069117 0.068651 0.068688 0.068816 0.068487 \n", + "\n", + "w1 0.7 0.8 0.9 1.0 \n", + "w0 \n", + "0.0 0.303220 0.303056 0.303098 0.303111 \n", + "0.1 0.153547 0.153855 0.153855 0.153808 \n", + "0.2 0.121616 0.121897 0.121875 0.121903 \n", + "0.3 0.108730 0.108983 0.109084 0.109028 \n", + "0.4 0.096792 0.097016 0.097134 0.097033 \n", + "0.5 0.082558 0.082973 0.083060 0.083053 \n", + "0.6 0.074234 0.074627 0.074700 0.074654 \n", + "0.7 0.072093 0.072497 0.072602 0.072725 \n", + "0.8 0.070711 0.071131 0.071268 0.071324 \n", + "0.9 0.069573 0.069981 0.070063 0.070120 \n", + "1.0 0.068197 0.068566 0.068625 0.068661 " + ] + }, + "execution_count": 107, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "df_plot" + ] + }, + { + "cell_type": "code", + "execution_count": 108, + "id": "c156e841-fb9c-4ebf-bde8-d0277bed51b6", + "metadata": {}, + "outputs": [ + { + "data": { + "text/plain": [ + "w1\n", + "0.0 0.108108\n", + "0.1 0.118117\n", + "0.2 0.111837\n", + "0.3 0.111453\n", + "0.4 0.111343\n", + "0.5 0.111536\n", + "0.6 0.111237\n", + "0.7 0.111025\n", + "0.8 0.111326\n", + "0.9 0.111397\n", + "1.0 
0.111402\n", + "dtype: float64" + ] + }, + "execution_count": 108, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "df_plot.mean()" + ] + }, + { + "cell_type": "code", + "execution_count": 46, + "id": "30444115-8ac8-4b84-9faa-8bb081a5dfd1", + "metadata": {}, + "outputs": [ + { + "data": { + "text/html": [ + "
\n", + "\n", + "\n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + "
ApproachMeanBiasStdRMSE
0$\\hat{v}(\\pi_e)$1.860696-0.0393040.1130680.119705
1$\\hat{v}(\\pi_b)$1.162696-0.7373040.1138930.746049
2OIS1.780833-0.1191670.2232680.253080
3WIS1.860249-0.0397510.2149850.218629
4C-OIS1.780833-0.1191670.2232680.253080
5C-WIS1.860249-0.0397510.2149850.218629
\n", + "
" + ], + "text/plain": [ + " Approach Mean Bias Std RMSE\n", + "0 $\\hat{v}(\\pi_e)$ 1.860696 -0.039304 0.113068 0.119705\n", + "1 $\\hat{v}(\\pi_b)$ 1.162696 -0.737304 0.113893 0.746049\n", + "2 OIS 1.780833 -0.119167 0.223268 0.253080\n", + "3 WIS 1.860249 -0.039751 0.214985 0.218629\n", + "4 C-OIS 1.780833 -0.119167 0.223268 0.253080\n", + "5 C-WIS 1.860249 -0.039751 0.214985 0.218629" + ] + }, + "execution_count": 46, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "single_run_pcannot_impute(np.array([[0,0]]), runs=10)" + ] + }, + { + "cell_type": "code", + "execution_count": 49, + "id": "6a8ee74c-7843-456a-a638-b7116ccd6e42", + "metadata": {}, + "outputs": [ + { + "data": { + "text/html": [ + "
\n", + "\n", + "\n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + "
ApproachMeanBiasStdRMSE
0$\\hat{v}(\\pi_e)$1.860696-0.0393040.1130680.119705
1$\\hat{v}(\\pi_b)$1.162696-0.7373040.1138930.746049
2OIS1.780833-0.1191670.2232680.253080
3WIS1.860249-0.0397510.2149850.218629
4C-OIS1.888505-0.0114950.0699310.070870
5C-WIS1.888505-0.0114950.0699310.070870
\n", + "
" + ], + "text/plain": [ + " Approach Mean Bias Std RMSE\n", + "0 $\\hat{v}(\\pi_e)$ 1.860696 -0.039304 0.113068 0.119705\n", + "1 $\\hat{v}(\\pi_b)$ 1.162696 -0.737304 0.113893 0.746049\n", + "2 OIS 1.780833 -0.119167 0.223268 0.253080\n", + "3 WIS 1.860249 -0.039751 0.214985 0.218629\n", + "4 C-OIS 1.888505 -0.011495 0.069931 0.070870\n", + "5 C-WIS 1.888505 -0.011495 0.069931 0.070870" + ] + }, + "execution_count": 49, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "single_run_pcannot_impute(np.array([[1,1]]), runs=10)" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "679d1d8a-5626-4ad8-9956-f9ada38b7009", + "metadata": {}, + "outputs": [], + "source": [] + } + ], + "metadata": { + "kernelspec": { + "display_name": "RL_venv", + "language": "python", + "name": "rl_venv" + }, + "language_info": { + "codemirror_mode": { + "name": "ipython", + "version": 3 + }, + "file_extension": ".py", + "mimetype": "text/x-python", + "name": "python", + "nbconvert_exporter": "python", + "pygments_lexer": "ipython3", + "version": "3.8.8" + } + }, + "nbformat": 4, + "nbformat_minor": 5 +} diff --git a/synthetic/bandit_sweepW.ipynb b/synthetic/bandit_sweepW.ipynb new file mode 100644 index 0000000..56b8b23 --- /dev/null +++ b/synthetic/bandit_sweepW.ipynb @@ -0,0 +1,24846 @@ +{ + "cells": [ + { + "cell_type": "code", + "execution_count": 1, + "id": "c7641324-81a4-4d2b-96cc-984469e52494", + "metadata": {}, + "outputs": [], + "source": [ + "import numpy as np\n", + "import pandas as pd\n", + "from IPython.display import display\n", + "import matplotlib.pyplot as plt\n", + "%config InlineBackend.figure_formats = ['svg']\n", + "import matplotlib\n", + "matplotlib.rcParams['text.usetex'] = True\n", + "matplotlib.rcParams['font.sans-serif'] = ['FreeSans']\n", + "import seaborn as sns\n", + "import itertools\n", + "from tqdm import tqdm\n", + "import joblib" + ] + }, + { + "cell_type": "code", + "execution_count": 2, + "id": 
"3fedec9e-42cf-4a21-9343-4c288ccb4268", + "metadata": {}, + "outputs": [], + "source": [ + "d0 = np.array([1.])" + ] + }, + { + "cell_type": "code", + "execution_count": 3, + "id": "dc809890-c83b-47a7-8862-70b7db39da30", + "metadata": {}, + "outputs": [], + "source": [ + "πs = [\n", + " np.array([[1., 0.],]),\n", + " np.array([[0., 1.],]),\n", + " np.array([[0.5, 0.5],]),\n", + " np.array([[0.1, 0.9],]),\n", + " np.array([[0.8, 0.2],]),\n", + "]" + ] + }, + { + "cell_type": "code", + "execution_count": 4, + "id": "65192074-bcac-4807-b269-9aaccd9016df", + "metadata": {}, + "outputs": [], + "source": [ + "use_πD = False" + ] + }, + { + "cell_type": "code", + "execution_count": 250, + "id": "c43eebdf-c021-41b4-ae98-305815f63144", + "metadata": {}, + "outputs": [], + "source": [ + "N = 1" + ] + }, + { + "cell_type": "code", + "execution_count": 251, + "id": "256b744e-9d72-48ca-b66c-171e467bb6f5", + "metadata": {}, + "outputs": [], + "source": [ + "def single_run(runs = 1000, annot_std_scale=1.0):\n", + " np.random.seed(42)\n", + "\n", + " # True value of π_e\n", + " Js = []\n", + " for seed in range(runs):\n", + " rng = np.random.default_rng(seed=10+seed)\n", + " x = rng.choice(1, size=N, p=d0)\n", + " a = np.array([rng.choice(2, p=π_e[xi]) for xi in x])\n", + " r = np.array([rng.normal(R[xi,ai], sigma[xi,ai]) for xi,ai in zip(x,a)])\n", + " J = np.sum(r) / N\n", + " Js.append(J)\n", + "\n", + " # Standard IS\n", + " Gs = []\n", + " OISs = []\n", + " WISs = []\n", + " for seed in range(runs):\n", + " rng = np.random.default_rng(seed=10+seed)\n", + " x = rng.choice(1, size=N, p=d0)\n", + " a = np.array([rng.choice(2, p=π_b[xi]) for xi in x])\n", + " r = np.array([rng.normal(R[xi,ai], sigma[xi,ai]) for xi,ai in zip(x,a)])\n", + " G = np.sum(r) / N\n", + " Gs.append(G)\n", + "\n", + " if use_πD:\n", + " π_b_ = np.array([\n", + " [(np.sum((x==0)&(a==0)))/np.sum(x==0), \n", + " (np.sum((x==0)&(a==1)))/np.sum(x==0)],\n", + " ])\n", + " else:\n", + " π_b_ = π_b\n", + "\n", + 
" rho = π_e[x,a] / π_b_[x,a]\n", + " OISs.append(np.sum(rho * r) / N)\n", + " WISs.append(np.sum(rho * r) / np.sum(rho))\n", + "\n", + "\n", + " # Collect data using π_b - combining counterfactuals with factuals\n", + " FC_OISs_w = []\n", + " FC_WISs_w = []\n", + " for seed in range(runs):\n", + " rng = np.random.default_rng(seed=10+seed)\n", + " rng_c = np.random.default_rng(seed=100000+seed)\n", + " x = rng.choice(1, size=N, p=d0)\n", + " a = np.array([rng.choice(2, p=π_b[xi]) for xi in x])\n", + " r = np.array([rng.normal(R[xi,ai], sigma[xi,ai]) for xi,ai in zip(x,a)])\n", + " rho = π_e[x,a] / π_b[x,a]\n", + "\n", + " # counterfactual flag\n", + " c = np.array([rng_c.choice(2, p=[1-Pc[xi,ai], Pc[xi,ai]]) for xi,ai in zip(x,a)])\n", + "\n", + " # counterfactual reward\n", + " rc = np.array([rng_c.normal(R[xi,1-ai], annot_std_scale*sigma[xi,1-ai]) for xi,ai in zip(x,a)])\n", + " rc[c==0] = np.nan\n", + "\n", + " # trajectory-wise weight\n", + " w = np.ones(N)\n", + " w[c==1] = ww[x[c==1], a[c==1], a[c==1]]\n", + " wc = np.zeros(N)\n", + " wc[c==1] = ww[x[c==1], a[c==1], 1-a[c==1]]\n", + "\n", + " if use_πD:\n", + " # augmented behavior policy\n", + " π_b_ = np.array([\n", + " [(np.sum(w*((x==0)&(a==0)))+np.sum(wc*((x==0)&(a==1)&(c==1))))/(np.sum(w*(x==0))+np.sum(wc*((x==0)&(c==1)))), \n", + " (np.sum(w*((x==0)&(a==1)))+np.sum(wc*((x==0)&(a==0)&(c==1))))/(np.sum(w*(x==0))+np.sum(wc*((x==0)&(c==1))))],\n", + " ])\n", + " else:\n", + " # augmented behavior policy\n", + " π_b_ = np.array([\n", + " [π_b[0,0]*ww[0,0,0]+π_b[0,1]*ww[0,1,0], π_b[0,0]*ww[0,0,1]+π_b[0,1]*ww[0,1,1]],\n", + " ])\n", + " π_b_ = π_b_ / π_b_.sum(axis=1, keepdims=True)\n", + "\n", + " FC_OISs_w.append(\n", + " (np.sum(w* π_e[x,a] / π_b_[x,a] * r) + np.nansum(wc* π_e[x,1-a] / π_b_[x,1-a] * rc)) / (N)\n", + " )\n", + " FC_WISs_w.append(\n", + " (np.sum(w* π_e[x,a] / π_b_[x,a] * r) + np.nansum(wc* π_e[x,1-a] / π_b_[x,1-a] * rc)) / (np.sum(w* π_e[x,a] / π_b_[x,a]) + np.sum((wc* π_e[x,1-a] / 
π_b_[x,1-a])[c==1])),\n", + " )\n", + "\n", + " df_bias_var = []\n", + " for name, values in [\n", + " ('$\\hat{v}(\\pi_e)$', Js),\n", + " ('$\\hat{v}(\\pi_b)$', Gs),\n", + " ('OIS', OISs),\n", + " ('WIS', WISs),\n", + " ('C-OIS', FC_OISs_w),\n", + " ('C-WIS', FC_WISs_w),\n", + " ]:\n", + " df_bias_var.append([name, \n", + " np.mean(values), \n", + " np.mean(values - d0@np.sum(π_e*R, axis=1)), \n", + " np.sqrt(np.var(values)), \n", + " np.sqrt(np.mean(np.square(values - d0@np.sum(π_e*R, axis=1))))])\n", + " return pd.DataFrame(df_bias_var, columns=['Approach', 'Mean', 'Bias', 'Std', 'RMSE'])" + ] + }, + { + "cell_type": "markdown", + "id": "2494f08e-e610-4748-b841-1f0ce201203e", + "metadata": {}, + "source": [ + "# Ideal counterfactual annotations, sweep weights" + ] + }, + { + "cell_type": "code", + "execution_count": 7, + "id": "3e530e09-aba1-49d6-8412-5de6d6be5362", + "metadata": {}, + "outputs": [], + "source": [ + "## probability of getting a counterfactual annotation\n", + "Pc = np.array([\n", + " [1., 1.],\n", + "])" + ] + }, + { + "cell_type": "code", + "execution_count": 8, + "id": "1953b403-93de-476f-b403-e7e65c519366", + "metadata": {}, + "outputs": [], + "source": [ + "w_list = list(np.arange(0,1.+1e-10,0.1).round(2))" + ] + }, + { + "cell_type": "code", + "execution_count": 9, + "id": "c1f3e0f5-4129-49cd-9027-6dbc7a417034", + "metadata": {}, + "outputs": [ + { + "data": { + "text/plain": [ + "[0.0, 0.1, 0.2, 0.3, 0.4, 0.5, 0.6, 0.7, 0.8, 0.9, 1.0]" + ] + }, + "execution_count": 9, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "w_list" + ] + }, + { + "cell_type": "markdown", + "id": "bc27b096-ac34-4623-bbb8-a455859ff474", + "metadata": {}, + "source": [ + "## >>Same variance" + ] + }, + { + "cell_type": "code", + "execution_count": 389, + "id": "22152675-1f82-4c3d-9002-176a61eea9fe", + "metadata": {}, + "outputs": [], + "source": [ + "R, sigma = np.array([[1, 2.],]), np.array([[1., 1.],])" + ] + }, + { + "cell_type": 
"code", + "execution_count": 390, + "id": "8e070b96-8fd5-406a-87c8-0f9dfb15f060", + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "[[0.1 0.9]] [[0.8 0.2]]\n" + ] + } + ], + "source": [ + "π_b = πs[3]\n", + "π_e = πs[4]\n", + "print(π_b, π_e)" + ] + }, + { + "cell_type": "code", + "execution_count": 391, + "id": "36e661be-d2a7-4774-a7e9-82d450a6ee77", + "metadata": {}, + "outputs": [ + { + "name": "stderr", + "output_type": "stream", + "text": [ + ":77: RuntimeWarning: invalid value encountered in true_divide\n", + " (np.sum(w* π_e[x,a] / π_b_[x,a] * r) + np.nansum(wc* π_e[x,1-a] / π_b_[x,1-a] * rc)) / (N)\n", + ":80: RuntimeWarning: invalid value encountered in true_divide\n", + " (np.sum(w* π_e[x,a] / π_b_[x,a] * r) + np.nansum(wc* π_e[x,1-a] / π_b_[x,1-a] * rc)) / (np.sum(w* π_e[x,a] / π_b_[x,a]) + np.sum((wc* π_e[x,1-a] / π_b_[x,1-a])[c==1])),\n" + ] + } + ], + "source": [ + "df_out_sweepW_3a = []\n", + "for wwww, www_ in itertools.product(w_list, w_list):\n", + " ww = np.array([[\n", + " [wwww, 1-wwww],\n", + " [1-www_, www_],\n", + " ]])\n", + " df_out = single_run()\n", + " df_out_sweepW_3a.append(df_out)" + ] + }, + { + "cell_type": "code", + "execution_count": 392, + "id": "4404f5a6-94d1-4b4b-acfb-f1f57d9ef491", + "metadata": { + "tags": [] + }, + "outputs": [ + { + "data": { + "text/html": [ + "\n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " 
\n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + "
0.0 0.1 0.2 0.3 0.4 0.5 0.6 0.7 0.8 0.9 1.0
wwww
0.0[0.09 3.37 3.37]\n", + " [0.06 1.41 1.41][0.09 3.37 3.37]\n", + " [0.04 0.92 0.92][0.09 3.37 3.37]\n", + " [0.04 0.84 0.84][0.09 3.37 3.37]\n", + " [0.03 0.83 0.83][0.09 3.37 3.37]\n", + " [0.03 0.83 0.83][0.09 3.37 3.37]\n", + " [0.03 0.84 0.84][0.09 3.37 3.37]\n", + " [0.03 0.85 0.85][0.09 3.37 3.37]\n", + " [0.03 0.85 0.85][0.09 3.37 3.37]\n", + " [0.03 0.86 0.86][0.09 3.37 3.37]\n", + " [0.03 0.87 0.87][0.09 3.37 3.37]\n", + " [ nan nan nan]
0.1[0.09 3.37 3.37]\n", + " [0.06 1.43 1.43][0.09 3.37 3.37]\n", + " [0.04 0.91 0.91][0.09 3.37 3.37]\n", + " [0.04 0.83 0.83][0.09 3.37 3.37]\n", + " [0.03 0.82 0.82][0.09 3.37 3.37]\n", + " [0.03 0.82 0.82][0.09 3.37 3.37]\n", + " [0.03 0.82 0.82][0.09 3.37 3.37]\n", + " [0.03 0.82 0.82][0.09 3.37 3.37]\n", + " [0.03 0.81 0.81][0.09 3.37 3.37]\n", + " [0.03 0.8 0.8 ][0.09 3.37 3.37]\n", + " [0.04 0.78 0.78][0.09 3.37 3.37]\n", + " [0.09 3.47 3.47]
0.2[0.09 3.37 3.37]\n", + " [0.06 1.45 1.45][0.09 3.37 3.37]\n", + " [0.04 0.89 0.9 ][0.09 3.37 3.37]\n", + " [0.04 0.82 0.82][0.09 3.37 3.37]\n", + " [0.04 0.81 0.81][0.09 3.37 3.37]\n", + " [0.03 0.8 0.8 ][0.09 3.37 3.37]\n", + " [0.03 0.8 0.8 ][0.09 3.37 3.37]\n", + " [0.03 0.8 0.8 ][0.09 3.37 3.37]\n", + " [0.03 0.79 0.79][0.09 3.37 3.37]\n", + " [0.04 0.78 0.78][0.09 3.37 3.37]\n", + " [0.04 0.84 0.84][0.09 3.37 3.37]\n", + " [0.09 3.46 3.46]
0.3[0.09 3.37 3.37]\n", + " [0.06 1.47 1.47][0.09 3.37 3.37]\n", + " [0.04 0.88 0.88][0.09 3.37 3.37]\n", + " [0.04 0.81 0.81][0.09 3.37 3.37]\n", + " [0.04 0.8 0.8 ][0.09 3.37 3.37]\n", + " [0.03 0.79 0.8 ][0.09 3.37 3.37]\n", + " [0.03 0.79 0.79][0.09 3.37 3.37]\n", + " [0.03 0.79 0.79][0.09 3.37 3.37]\n", + " [0.04 0.78 0.78][0.09 3.37 3.37]\n", + " [0.04 0.8 0.8 ][0.09 3.37 3.37]\n", + " [0.04 0.95 0.95][0.09 3.37 3.37]\n", + " [0.09 3.45 3.45]
0.4[0.09 3.37 3.37]\n", + " [0.06 1.49 1.49][0.09 3.37 3.37]\n", + " [0.04 0.86 0.86][0.09 3.37 3.37]\n", + " [0.04 0.8 0.8 ][0.09 3.37 3.37]\n", + " [0.04 0.79 0.79][0.09 3.37 3.37]\n", + " [0.03 0.79 0.79][0.09 3.37 3.37]\n", + " [0.03 0.79 0.79][0.09 3.37 3.37]\n", + " [0.04 0.78 0.78][0.09 3.37 3.37]\n", + " [0.04 0.79 0.79][0.09 3.37 3.37]\n", + " [0.04 0.83 0.83][0.09 3.37 3.37]\n", + " [0.05 1.08 1.08][0.09 3.37 3.37]\n", + " [0.09 3.44 3.44]
0.5[0.09 3.37 3.37]\n", + " [0.06 1.5 1.51][0.09 3.37 3.37]\n", + " [0.04 0.84 0.84][0.09 3.37 3.37]\n", + " [0.04 0.79 0.8 ][0.09 3.37 3.37]\n", + " [0.04 0.79 0.79][0.09 3.37 3.37]\n", + " [0.04 0.78 0.79][0.09 3.37 3.37]\n", + " [0.04 0.78 0.78][0.09 3.37 3.37]\n", + " [0.04 0.79 0.79][0.09 3.37 3.37]\n", + " [0.04 0.8 0.81][0.09 3.37 3.37]\n", + " [0.04 0.88 0.88][0.09 3.37 3.37]\n", + " [0.05 1.21 1.21][0.09 3.37 3.37]\n", + " [0.09 3.43 3.43]
0.6[0.09 3.37 3.37]\n", + " [0.06 1.52 1.53][0.09 3.37 3.37]\n", + " [0.04 0.82 0.83][0.09 3.37 3.37]\n", + " [0.04 0.79 0.79][0.09 3.37 3.37]\n", + " [0.04 0.78 0.78][0.09 3.37 3.37]\n", + " [0.04 0.78 0.78][0.09 3.37 3.37]\n", + " [0.04 0.78 0.79][0.09 3.37 3.37]\n", + " [0.04 0.79 0.79][0.09 3.37 3.37]\n", + " [0.04 0.83 0.83][0.09 3.37 3.37]\n", + " [0.04 0.94 0.94][0.09 3.37 3.37]\n", + " [0.05 1.33 1.33][0.09 3.37 3.37]\n", + " [0.09 3.41 3.42]
0.7[0.09 3.37 3.37]\n", + " [0.06 1.54 1.55][0.09 3.37 3.37]\n", + " [0.04 0.81 0.81][0.09 3.37 3.37]\n", + " [0.04 0.78 0.79][0.09 3.37 3.37]\n", + " [0.04 0.78 0.78][0.09 3.37 3.37]\n", + " [0.04 0.78 0.78][0.09 3.37 3.37]\n", + " [0.04 0.79 0.79][0.09 3.37 3.37]\n", + " [0.04 0.81 0.81][0.09 3.37 3.37]\n", + " [0.04 0.85 0.86][0.09 3.37 3.37]\n", + " [0.04 1. 1. ][0.09 3.37 3.37]\n", + " [0.05 1.43 1.43][0.09 3.37 3.37]\n", + " [0.09 3.4 3.41]
0.8[0.09 3.37 3.37]\n", + " [0.06 1.56 1.57][0.09 3.37 3.37]\n", + " [0.04 0.79 0.79][0.09 3.37 3.37]\n", + " [0.04 0.78 0.78][0.09 3.37 3.37]\n", + " [0.04 0.78 0.78][0.09 3.37 3.37]\n", + " [0.04 0.79 0.79][0.09 3.37 3.37]\n", + " [0.04 0.8 0.8 ][0.09 3.37 3.37]\n", + " [0.04 0.82 0.82][0.09 3.37 3.37]\n", + " [0.04 0.89 0.89][0.09 3.37 3.37]\n", + " [0.05 1.05 1.06][0.09 3.37 3.37]\n", + " [0.06 1.53 1.53][0.09 3.37 3.37]\n", + " [0.09 3.39 3.39]
0.9[0.09 3.37 3.37]\n", + " [0.06 1.59 1.59][0.09 3.37 3.37]\n", + " [0.04 0.78 0.78][0.09 3.37 3.37]\n", + " [0.03 0.78 0.79][0.09 3.37 3.37]\n", + " [0.03 0.79 0.79][0.09 3.37 3.37]\n", + " [0.04 0.79 0.79][0.09 3.37 3.37]\n", + " [0.04 0.8 0.81][0.09 3.37 3.37]\n", + " [0.04 0.84 0.84][0.09 3.37 3.37]\n", + " [0.04 0.92 0.92][0.09 3.37 3.37]\n", + " [0.05 1.11 1.11][0.09 3.37 3.37]\n", + " [0.06 1.61 1.62][0.09 3.37 3.37]\n", + " [0.09 3.38 3.38]
1.0[0.09 3.37 3.37]\n", + " [ nan nan nan][0.09 3.37 3.37]\n", + " [0.03 0.79 0.79][0.09 3.37 3.37]\n", + " [0.03 0.79 0.79][0.09 3.37 3.37]\n", + " [0.03 0.79 0.79][0.09 3.37 3.37]\n", + " [0.04 0.8 0.8 ][0.09 3.37 3.37]\n", + " [0.04 0.82 0.82][0.09 3.37 3.37]\n", + " [0.04 0.86 0.86][0.09 3.37 3.37]\n", + " [0.04 0.95 0.95][0.09 3.37 3.37]\n", + " [0.05 1.17 1.17][0.09 3.37 3.37]\n", + " [0.06 1.69 1.69][0.09 3.37 3.37]\n", + " [0.09 3.37 3.37]
" + ], + "text/plain": [ + "" + ] + }, + "metadata": {}, + "output_type": "display_data" + } + ], + "source": [ + "df_tmp = pd.DataFrame(index=w_list, columns=w_list)\n", + "df_tmp.index.name = 'wwww'\n", + "for (i, wwww), (j, www_) in itertools.product(enumerate(w_list), enumerate(w_list)):\n", + " ix = i*len(w_list) + j\n", + " df_tmp.iloc[i, j] = str(df_out_sweepW_3a[ix].iloc[[2,4], [2,3,4]].round(2).values)[1:-1]\n", + "display(df_tmp.style.set_properties(**{'white-space': 'pre-wrap'}))" + ] + }, + { + "cell_type": "code", + "execution_count": 408, + "id": "599ac1fc-2c37-49f5-ace4-bced07c60154", + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "$R_0 \\sim N(1.0,1.0^2)$, $R_1 \\sim N(2.0, 1.0^2)$ \n", + " $\\pi_b=[0.8,0.2]$, $\\pi_e=[0.1,0.9]$\n" + ] + }, + { + "data": { + "image/svg+xml": [ + "\n", + "\n", + "\n", + "\n", + " \n", + " \n", + " \n", + " \n", + " 2023-04-25T16:48:01.822621\n", + " image/svg+xml\n", + " \n", + " \n", + " Matplotlib v3.3.4, https://matplotlib.org/\n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " 
\n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " 
\n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + "\n" + ], + "text/plain": [ + "
" + ] + }, + "metadata": { + "needs_background": "light" + }, + "output_type": "display_data" + }, + { + "data": { + "image/svg+xml": [ + "\n", + "\n", + "\n", + "\n", + " \n", + " \n", + " \n", + " \n", + " 2023-04-25T16:48:02.944333\n", + " image/svg+xml\n", + " \n", + " \n", + " Matplotlib v3.3.4, https://matplotlib.org/\n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " 
\n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + "\n" + ], + "text/plain": [ + "
" + ] + }, + "metadata": { + "needs_background": "light" + }, + "output_type": "display_data" + }, + { + "data": { + "image/svg+xml": [ + "\n", + "\n", + "\n", + "\n", + " \n", + " \n", + " \n", + " \n", + " 2023-04-25T16:48:03.395759\n", + " image/svg+xml\n", + " \n", + " \n", + " Matplotlib v3.3.4, https://matplotlib.org/\n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " 
\n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + "\n" + ], + "text/plain": [ + "
" + ] + }, + "metadata": { + "needs_background": "light" + }, + "output_type": "display_data" + } + ], + "source": [ + "variances_grid = [df__.iloc[4,3] for df__ in df_out_sweepW_3a]\n", + "\n", + "fig, ax = plt.subplots(figsize=(3,3))\n", + "df_plot = pd.DataFrame(np.array(variances_grid).reshape((11,11)), index=w_list, columns=w_list)\n", + "df_plot.index.name = 'w0'\n", + "df_plot.index = np.round(df_plot.index, 5)\n", + "df_plot.columns.name = 'w1'\n", + "df_plot.columns = np.round(df_plot.columns, 5)\n", + "\n", + "# sns.heatmap(df_plot.T, cmap='crest', square=True, ax=ax, \n", + "# cbar_kws={'label': 'Var', \"shrink\": .82}, vmin=0, vmax=4)\n", + "sns.heatmap(np.log(df_plot.T), cmap='mako_r', square=True, ax=ax,\n", + " cbar_kws=dict(shrink=.82, aspect=40, pad=0.04), \n", + " vmin=-0.3, vmax=1.3,\n", + " )\n", + "# ax.collections[0].colorbar.set_label('$\\log(\\mathrm{Var})$', labelpad=-9, fontsize=8)\n", + "ax.collections[0].colorbar.set_ticks([0,1])\n", + "plt.xticks([0.5,10.5], ['$0$','$1$'], rotation=0)\n", + "plt.yticks([0.5,10.5], ['$0$','$1$'])\n", + "ax.yaxis.set_ticks_position('none')\n", + "ax.xaxis.set_ticks_position('none')\n", + "plt.xlabel('$W(\\\\tilde{a}=0|s,a=0)$', labelpad=-9)\n", + "plt.ylabel('$W(\\\\tilde{a}=1|s,a=1)$', labelpad=-9)\n", + "\n", + "for (i,w0), (j,w1) in itertools.product(enumerate(w_list), enumerate(w_list)):\n", + " if np.isnan(df_plot.loc[w0,w1]):\n", + " plt.plot(i+0.5,j+0.5, marker='x', c='gray')\n", + " if np.isclose(df_plot.loc[w0,w1], np.nanmin(df_plot.values)):\n", + " plt.annotate('$*$', (i+0.5,j+0.5), c='k', ha='center', va='center')\n", + " # plt.plot(i+0.5,j+0.5, marker='.', c='yellow', ms=3)\n", + " if w0 == 0.5 and w1 == 0.5:\n", + " plt.plot(i+0.5,j+0.5, marker='o', mfc='none', mec='r', ms=10)\n", + "\n", + "plt.plot(10+0.5,10+0.5, marker='.', c='yellow', ms=8)\n", + "# plt.title('$R_0 \\sim N({},{}^2)$, $R_1 \\sim N({}, {}^2)$ \\n $\\pi_b=[{},{}]$, $\\pi_e=[{},{}]$'.format(\n", + "# R[0,0], sigma[0,0], 
R[0,1], sigma[0,1], π_b[0,0], π_b[0,1], π_e[0,0], π_e[0,1], \n", + "# ))\n", + "plt.gca().invert_yaxis()\n", + "ax.spines['top'].set_visible(True)\n", + "ax.spines['bottom'].set_visible(True)\n", + "ax.spines['right'].set_visible(True)\n", + "ax.spines['left'].set_visible(True)\n", + "\n", + "print('$R_0 \\sim N({},{}^2)$, $R_1 \\sim N({}, {}^2)$ \\n $\\pi_b=[{},{}]$, $\\pi_e=[{},{}]$'.format(\n", + " R[0,0], sigma[0,0], R[0,1], sigma[0,1], π_b[0,0], π_b[0,1], π_e[0,0], π_e[0,1], \n", + "))\n", + "plt.savefig('fig/bandit_logvar_b19_e82.pdf', bbox_inches='tight')\n", + "plt.show()\n", + "# display(df_plot)\n", + "\n", + "\n", + "\n", + "biases_grid = [df__.iloc[4,2] for df__ in df_out_sweepW_3a]\n", + "\n", + "fig, ax = plt.subplots(figsize=(3,3))\n", + "df_plot = pd.DataFrame(np.array(biases_grid).reshape((11,11)), index=w_list, columns=w_list)\n", + "df_plot.index.name = 'w0'\n", + "df_plot.index = np.round(df_plot.index, 5)\n", + "df_plot.columns.name = 'w1'\n", + "df_plot.columns = np.round(df_plot.columns, 5)\n", + "\n", + "sns.heatmap(df_plot.T, cmap='crest', square=True, ax=ax, \n", + " cbar_kws={'label': 'Bias', \"shrink\": .82}, vmin=0, vmax=2)\n", + "plt.xticks([0.5,10.5], ['$0$','$1$'], rotation=0)\n", + "plt.yticks([0.5,10.5], ['$0$','$1$'])\n", + "ax.yaxis.set_ticks_position('none')\n", + "ax.xaxis.set_ticks_position('none')\n", + "plt.xlabel('$W(\\\\tilde{a}=0|s,a=0)$', labelpad=-9)\n", + "plt.ylabel('$W(\\\\tilde{a}=1|s,a=1)$', labelpad=-9)\n", + "\n", + "for (i,w0), (j,w1) in itertools.product(enumerate(w_list), enumerate(w_list)):\n", + " if np.isnan(df_plot.loc[w0,w1]):\n", + " plt.plot(i+0.5,j+0.5, marker='x', c='r')\n", + "\n", + "# plt.title('$R_0 \\sim N({},{}^2)$, $R_1 \\sim N({}, {}^2)$ \\n $\\pi_b=[{},{}]$, $\\pi_e=[{},{}]$'.format(\n", + "# R[0,0], sigma[0,0], R[0,1], sigma[0,1], π_b[0,0], π_b[0,1], π_e[0,0], π_e[0,1], \n", + "# ))\n", + "plt.gca().invert_yaxis()\n", + "ax.spines['top'].set_visible(True)\n", + 
"ax.spines['bottom'].set_visible(True)\n", + "ax.spines['right'].set_visible(True)\n", + "ax.spines['left'].set_visible(True)\n", + "\n", + "# plt.show()\n", + "\n", + "\n", + "mses_grid = [df__.iloc[4,4] for df__ in df_out_sweepW_3a]\n", + "\n", + "fig, ax = plt.subplots(figsize=(3,3))\n", + "df_plot = pd.DataFrame(np.array(mses_grid).reshape((11,11)), index=w_list, columns=w_list)\n", + "df_plot.index.name = 'w0'\n", + "df_plot.index = np.round(df_plot.index, 5)\n", + "df_plot.columns.name = 'w1'\n", + "df_plot.columns = np.round(df_plot.columns, 5)\n", + "\n", + "sns.heatmap(df_plot.T, cmap='crest', square=True, ax=ax, \n", + " cbar_kws={'label': 'MSE', \"shrink\": .82}, vmin=0, vmax=4)\n", + "plt.xticks([0.5,10.5], [0,1], rotation=0)\n", + "plt.yticks([0.5,10.5], [0,1])\n", + "ax.yaxis.set_ticks_position('none')\n", + "ax.xaxis.set_ticks_position('none')\n", + "plt.xlabel('$W(\\\\tilde{a}=0|s,a=0)$', labelpad=-9)\n", + "plt.ylabel('$W(\\\\tilde{a}=1|s,a=1)$', labelpad=-9)\n", + "\n", + "for (i,w0), (j,w1) in itertools.product(enumerate(w_list), enumerate(w_list)):\n", + " if np.isnan(df_plot.loc[w0,w1]):\n", + " plt.plot(i+0.5,j+0.5, marker='x', c='r')\n", + "\n", + "# plt.title('$R_0 \\sim N({},{}^2)$, $R_1 \\sim N({}, {}^2)$ \\n $\\pi_b=[{},{}]$, $\\pi_e=[{},{}]$'.format(\n", + "# R[0,0], sigma[0,0], R[0,1], sigma[0,1], π_b[0,0], π_b[0,1], π_e[0,0], π_e[0,1], \n", + "# ))\n", + "plt.gca().invert_yaxis()\n", + "ax.spines['top'].set_visible(True)\n", + "ax.spines['bottom'].set_visible(True)\n", + "ax.spines['right'].set_visible(True)\n", + "ax.spines['left'].set_visible(True)\n", + "\n", + "plt.show()" + ] + }, + { + "cell_type": "code", + "execution_count": 394, + "id": "fa16fd83-00ad-4c6c-9eb0-26652bb2d875", + "metadata": {}, + "outputs": [ + { + "data": { + "text/html": [ + "
\n", + "\n", + "\n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + "
ApproachMeanBiasStdRMSE
0$\\hat{v}(\\pi_e)$1.2667170.0667171.0873411.089386
1$\\hat{v}(\\pi_b)$1.9627170.7627171.0478621.296052
2OIS1.2882620.0882623.3715713.372726
3WIS1.9627170.7627171.0478621.296052
4C-OIS1.2882620.0882623.3715713.372726
5C-WIS1.9627170.7627171.0478621.296052
\n", + "
" + ], + "text/plain": [ + " Approach Mean Bias Std RMSE\n", + "0 $\\hat{v}(\\pi_e)$ 1.266717 0.066717 1.087341 1.089386\n", + "1 $\\hat{v}(\\pi_b)$ 1.962717 0.762717 1.047862 1.296052\n", + "2 OIS 1.288262 0.088262 3.371571 3.372726\n", + "3 WIS 1.962717 0.762717 1.047862 1.296052\n", + "4 C-OIS 1.288262 0.088262 3.371571 3.372726\n", + "5 C-WIS 1.962717 0.762717 1.047862 1.296052" + ] + }, + "execution_count": 394, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "df_out_sweepW_3a[-1]" + ] + }, + { + "cell_type": "markdown", + "id": "b4f0a6e9-48f6-42c0-af34-cdb896bb2ee0", + "metadata": {}, + "source": [ + "## >Same variance v2" + ] + }, + { + "cell_type": "code", + "execution_count": 401, + "id": "9c51d238-e136-436e-a1c2-9d11e0e36753", + "metadata": {}, + "outputs": [], + "source": [ + "R, sigma = np.array([[1, 2.],]), np.array([[1., 1.],])" + ] + }, + { + "cell_type": "code", + "execution_count": 402, + "id": "a1fbf515-b2b2-4642-98b8-7ed6ceefe4c6", + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "[[0.8 0.2]] [[0.1 0.9]]\n" + ] + } + ], + "source": [ + "π_b = np.array([[0.8, 0.2]])\n", + "π_e = np.array([[0.1, 0.9]])\n", + "print(π_b, π_e)" + ] + }, + { + "cell_type": "code", + "execution_count": 403, + "id": "ce65713a-fe7a-4d93-9686-4d5100dfd4fe", + "metadata": {}, + "outputs": [ + { + "name": "stderr", + "output_type": "stream", + "text": [ + ":77: RuntimeWarning: invalid value encountered in true_divide\n", + " (np.sum(w* π_e[x,a] / π_b_[x,a] * r) + np.nansum(wc* π_e[x,1-a] / π_b_[x,1-a] * rc)) / (N)\n", + ":80: RuntimeWarning: invalid value encountered in true_divide\n", + " (np.sum(w* π_e[x,a] / π_b_[x,a] * r) + np.nansum(wc* π_e[x,1-a] / π_b_[x,1-a] * rc)) / (np.sum(w* π_e[x,a] / π_b_[x,a]) + np.sum((wc* π_e[x,1-a] / π_b_[x,1-a])[c==1])),\n" + ] + } + ], + "source": [ + "df_out_sweepW_3b = []\n", + "for wwww, www_ in itertools.product(w_list, w_list):\n", + " ww = np.array([[\n", 
+ " [wwww, 1-wwww],\n", + " [1-www_, www_],\n", + " ]])\n", + " df_out = single_run(runs=2000)\n", + " df_out_sweepW_3b.append(df_out)" + ] + }, + { + "cell_type": "code", + "execution_count": 404, + "id": "97f574b6-5871-4eb6-95ad-0b4a539002fb", + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "$R_0 \\sim N(1.0,1.0^2)$, $R_1 \\sim N(2.0, 1.0^2)$ \n", + " $\\pi_b=[0.8,0.2]$, $\\pi_e=[0.1,0.9]$\n" + ] + }, + { + "data": { + "image/svg+xml": [ + "\n", + "\n", + "\n", + "\n", + " \n", + " \n", + " \n", + " \n", + " 2023-04-25T15:49:15.655568\n", + " image/svg+xml\n", + " \n", + " \n", + " Matplotlib v3.3.4, https://matplotlib.org/\n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " 
\n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " 
\n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + "\n" + ], + "text/plain": [ + "
" + ] + }, + "metadata": { + "needs_background": "light" + }, + "output_type": "display_data" + } + ], + "source": [ + "variances_grid = [df__.iloc[4,3] for df__ in df_out_sweepW_3b]\n", + "\n", + "fig, ax = plt.subplots(figsize=(3,3))\n", + "df_plot = pd.DataFrame(np.array(variances_grid).reshape((11,11)), index=w_list, columns=w_list)\n", + "df_plot.index.name = 'w0'\n", + "df_plot.index = np.round(df_plot.index, 5)\n", + "df_plot.columns.name = 'w1'\n", + "df_plot.columns = np.round(df_plot.columns, 5)\n", + "\n", + "# sns.heatmap(df_plot.T, cmap='crest', square=True, ax=ax, \n", + "# cbar_kws={'label': 'Var', \"shrink\": .82}, vmin=0, vmax=4)\n", + "sns.heatmap(np.log(df_plot.T), cmap='mako_r', square=True, ax=ax,\n", + " cbar_kws=dict(shrink=.82, aspect=40, pad=0.04), \n", + " # vmin=-0.3, vmax=1.3,\n", + " )\n", + "# ax.collections[0].colorbar.set_label('$\\log(\\mathrm{Var})$', labelpad=-9, fontsize=8)\n", + "ax.collections[0].colorbar.set_ticks([0,1])\n", + "plt.xticks([0.5,10.5], ['$0$','$1$'], rotation=0)\n", + "plt.yticks([0.5,10.5], ['$0$','$1$'])\n", + "ax.yaxis.set_ticks_position('none')\n", + "ax.xaxis.set_ticks_position('none')\n", + "plt.xlabel('$W(\\\\tilde{a}=0|s,a=0)$', labelpad=-9)\n", + "plt.ylabel('$W(\\\\tilde{a}=1|s,a=1)$', labelpad=-9)\n", + "\n", + "for (i,w0), (j,w1) in itertools.product(enumerate(w_list), enumerate(w_list)):\n", + " if np.isnan(df_plot.loc[w0,w1]):\n", + " plt.plot(i+0.5,j+0.5, marker='x', c='r')\n", + " if np.isclose(df_plot.loc[w0,w1], np.nanmin(df_plot.values)):\n", + " plt.annotate('$*$', (i+0.5,j+0.5), c='k', ha='center', va='center')\n", + " # plt.plot(i+0.5,j+0.5, marker='.', c='yellow', ms=3)\n", + " if w0 == 0.5 and w1 == 0.5:\n", + " plt.plot(i+0.5,j+0.5, marker='o', mfc='none', mec='r', ms=10)\n", + "\n", + "plt.plot(10+0.5,10+0.5, marker='.', c='yellow', ms=8)\n", + "# plt.title('$R_0 \\sim N({},{}^2)$, $R_1 \\sim N({}, {}^2)$ \\n $\\pi_b=[{},{}]$, $\\pi_e=[{},{}]$'.format(\n", + "# R[0,0], sigma[0,0], 
R[0,1], sigma[0,1], π_b[0,0], π_b[0,1], π_e[0,0], π_e[0,1], \n", + "# ))\n", + "plt.gca().invert_yaxis()\n", + "ax.spines['top'].set_visible(True)\n", + "ax.spines['bottom'].set_visible(True)\n", + "ax.spines['right'].set_visible(True)\n", + "ax.spines['left'].set_visible(True)\n", + "\n", + "print('$R_0 \\sim N({},{}^2)$, $R_1 \\sim N({}, {}^2)$ \\n $\\pi_b=[{},{}]$, $\\pi_e=[{},{}]$'.format(\n", + " R[0,0], sigma[0,0], R[0,1], sigma[0,1], π_b[0,0], π_b[0,1], π_e[0,0], π_e[0,1], \n", + "))\n", + "plt.savefig('fig/bandit_logvar_b82_e19.pdf', bbox_inches='tight')\n", + "plt.show()\n", + "# display(df_plot)\n", + "\n", + "\n", + "\n", + "# biases_grid = [df__.iloc[4,2] for df__ in df_out_sweepW_3b]\n", + "\n", + "# fig, ax = plt.subplots(figsize=(3,3))\n", + "# df_plot = pd.DataFrame(np.array(biases_grid).reshape((11,11)), index=w_list, columns=w_list)\n", + "# df_plot.index.name = 'w0'\n", + "# df_plot.index = np.round(df_plot.index, 5)\n", + "# df_plot.columns.name = 'w1'\n", + "# df_plot.columns = np.round(df_plot.columns, 5)\n", + "\n", + "# sns.heatmap(df_plot.T, cmap='crest', square=True, ax=ax, \n", + "# cbar_kws={'label': 'Bias', \"shrink\": .82}, vmin=0, vmax=2)\n", + "# plt.xticks([0.5,10.5], ['$0$','$1$'], rotation=0)\n", + "# plt.yticks([0.5,10.5], ['$0$','$1$'])\n", + "# ax.yaxis.set_ticks_position('none')\n", + "# ax.xaxis.set_ticks_position('none')\n", + "# plt.xlabel('$W(\\\\tilde{a}=0|s,a=0)$', labelpad=-9)\n", + "# plt.ylabel('$W(\\\\tilde{a}=1|s,a=1)$', labelpad=-9)\n", + "\n", + "# for (i,w0), (j,w1) in itertools.product(enumerate(w_list), enumerate(w_list)):\n", + "# if np.isnan(df_plot.loc[w0,w1]):\n", + "# plt.plot(i+0.5,j+0.5, marker='x', c='r')\n", + "\n", + "# # plt.title('$R_0 \\sim N({},{}^2)$, $R_1 \\sim N({}, {}^2)$ \\n $\\pi_b=[{},{}]$, $\\pi_e=[{},{}]$'.format(\n", + "# # R[0,0], sigma[0,0], R[0,1], sigma[0,1], π_b[0,0], π_b[0,1], π_e[0,0], π_e[0,1], \n", + "# # ))\n", + "# plt.gca().invert_yaxis()\n", + "# 
ax.spines['top'].set_visible(True)\n", + "# ax.spines['bottom'].set_visible(True)\n", + "# ax.spines['right'].set_visible(True)\n", + "# ax.spines['left'].set_visible(True)\n", + "\n", + "# # plt.show()\n", + "\n", + "\n", + "# mses_grid = [df__.iloc[4,4] for df__ in df_out_sweepW_3b]\n", + "\n", + "# fig, ax = plt.subplots(figsize=(3,3))\n", + "# df_plot = pd.DataFrame(np.array(mses_grid).reshape((11,11)), index=w_list, columns=w_list)\n", + "# df_plot.index.name = 'w0'\n", + "# df_plot.index = np.round(df_plot.index, 5)\n", + "# df_plot.columns.name = 'w1'\n", + "# df_plot.columns = np.round(df_plot.columns, 5)\n", + "\n", + "# sns.heatmap(df_plot.T, cmap='crest', square=True, ax=ax, \n", + "# cbar_kws={'label': 'MSE', \"shrink\": .82}, vmin=0, vmax=4)\n", + "# plt.xticks([0.5,10.5], [0,1], rotation=0)\n", + "# plt.yticks([0.5,10.5], [0,1])\n", + "# ax.yaxis.set_ticks_position('none')\n", + "# ax.xaxis.set_ticks_position('none')\n", + "# plt.xlabel('$W(\\\\tilde{a}=0|s,a=0)$', labelpad=-9)\n", + "# plt.ylabel('$W(\\\\tilde{a}=1|s,a=1)$', labelpad=-9)\n", + "\n", + "# for (i,w0), (j,w1) in itertools.product(enumerate(w_list), enumerate(w_list)):\n", + "# if np.isnan(df_plot.loc[w0,w1]):\n", + "# plt.plot(i+0.5,j+0.5, marker='x', c='r')\n", + "\n", + "# # plt.title('$R_0 \\sim N({},{}^2)$, $R_1 \\sim N({}, {}^2)$ \\n $\\pi_b=[{},{}]$, $\\pi_e=[{},{}]$'.format(\n", + "# # R[0,0], sigma[0,0], R[0,1], sigma[0,1], π_b[0,0], π_b[0,1], π_e[0,0], π_e[0,1], \n", + "# # ))\n", + "# plt.gca().invert_yaxis()\n", + "# ax.spines['top'].set_visible(True)\n", + "# ax.spines['bottom'].set_visible(True)\n", + "# ax.spines['right'].set_visible(True)\n", + "# ax.spines['left'].set_visible(True)\n", + "\n", + "# plt.show()" + ] + }, + { + "cell_type": "code", + "execution_count": 405, + "id": "9c7f67ad-6362-43e9-a88c-96e3f230ea89", + "metadata": {}, + "outputs": [ + { + "data": { + "text/html": [ + "
\n", + "\n", + "\n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + "
ApproachMeanBiasStdRMSE
0$\\hat{v}(\\pi_e)$1.9230130.0230131.0302161.030473
1$\\hat{v}(\\pi_b)$1.235013-0.6649871.0782041.266780
2OIS2.0275120.1275124.2303174.232238
3WIS1.235013-0.6649871.0782041.266780
4C-OIS2.0275120.1275124.2303174.232238
5C-WIS1.235013-0.6649871.0782041.266780
\n", + "
" + ], + "text/plain": [ + " Approach Mean Bias Std RMSE\n", + "0 $\\hat{v}(\\pi_e)$ 1.923013 0.023013 1.030216 1.030473\n", + "1 $\\hat{v}(\\pi_b)$ 1.235013 -0.664987 1.078204 1.266780\n", + "2 OIS 2.027512 0.127512 4.230317 4.232238\n", + "3 WIS 1.235013 -0.664987 1.078204 1.266780\n", + "4 C-OIS 2.027512 0.127512 4.230317 4.232238\n", + "5 C-WIS 1.235013 -0.664987 1.078204 1.266780" + ] + }, + "execution_count": 405, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "df_out_sweepW_3b[-1]" + ] + }, + { + "cell_type": "markdown", + "id": "c7973f35-eac1-4fd8-9828-7a8072f78452", + "metadata": {}, + "source": [ + "## >>Annot has larger variance" + ] + }, + { + "cell_type": "code", + "execution_count": 49, + "id": "417bd7a3-ccc7-450a-9e0b-a39b7fd18e85", + "metadata": {}, + "outputs": [], + "source": [ + "R, sigma = np.array([[1, 2.],]), np.array([[1, 1],])" + ] + }, + { + "cell_type": "code", + "execution_count": 99, + "id": "17ea8c06-d06d-43eb-8751-7823afad5b4f", + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "[[0.1 0.9]] [[0.8 0.2]]\n" + ] + } + ], + "source": [ + "π_b = πs[3]\n", + "π_e = πs[4]\n", + "print(π_b, π_e)" + ] + }, + { + "cell_type": "code", + "execution_count": 54, + "id": "49f45f38-b3ec-4a94-9c87-988537de2da4", + "metadata": {}, + "outputs": [ + { + "name": "stderr", + "output_type": "stream", + "text": [ + ":77: RuntimeWarning: invalid value encountered in true_divide\n", + " (np.sum(w* π_e[x,a] / π_b_[x,a] * r) + np.nansum(wc* π_e[x,1-a] / π_b_[x,1-a] * rc)) / (N)\n", + ":80: RuntimeWarning: invalid value encountered in true_divide\n", + " (np.sum(w* π_e[x,a] / π_b_[x,a] * r) + np.nansum(wc* π_e[x,1-a] / π_b_[x,1-a] * rc)) / (np.sum(w* π_e[x,a] / π_b_[x,a]) + np.sum((wc* π_e[x,1-a] / π_b_[x,1-a])[c==1])),\n" + ] + } + ], + "source": [ + "df_out_sweepW_4 = []\n", + "for wwww, www_ in itertools.product(w_list, w_list):\n", + " ww = np.array([[\n", + " [wwww, 1-wwww],\n", + " 
[1-www_, www_],\n", + " ]])\n", + " df_out = single_run(annot_std_scale=2.0)\n", + " df_out_sweepW_4.append(df_out)" + ] + }, + { + "cell_type": "code", + "execution_count": 55, + "id": "1846175c-bf8a-4869-b209-b0b33a9d53ac", + "metadata": { + "tags": [] + }, + "outputs": [ + { + "data": { + "text/html": [ + "\n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + "
0.0 0.1 0.2 0.3 0.4 0.5 0.6 0.7 0.8 0.9 1.0
wwww
0.0[0.09 3.37 3.37]\n", + " [0.09 2.24 2.25][0.09 3.37 3.37]\n", + " [0.07 1.75 1.75][0.09 3.37 3.37]\n", + " [0.06 1.66 1.66][0.09 3.37 3.37]\n", + " [0.06 1.64 1.64][0.09 3.37 3.37]\n", + " [0.06 1.63 1.63][0.09 3.37 3.37]\n", + " [0.05 1.63 1.63][0.09 3.37 3.37]\n", + " [0.05 1.63 1.63][0.09 3.37 3.37]\n", + " [0.05 1.63 1.63][0.09 3.37 3.37]\n", + " [0.05 1.63 1.63][0.09 3.37 3.37]\n", + " [0.05 1.63 1.64][0.09 3.37 3.37]\n", + " [ nan nan nan]
0.1[0.09 3.37 3.37]\n", + " [0.1 2.25 2.25][0.09 3.37 3.37]\n", + " [0.07 1.72 1.72][0.09 3.37 3.37]\n", + " [0.06 1.63 1.63][0.09 3.37 3.37]\n", + " [0.06 1.61 1.61][0.09 3.37 3.37]\n", + " [0.06 1.6 1.6 ][0.09 3.37 3.37]\n", + " [0.05 1.59 1.59][0.09 3.37 3.37]\n", + " [0.05 1.58 1.58][0.09 3.37 3.37]\n", + " [0.05 1.57 1.57][0.09 3.37 3.37]\n", + " [0.05 1.54 1.54][0.09 3.37 3.37]\n", + " [0.05 1.47 1.47][0.09 3.37 3.37]\n", + " [0.09 3.47 3.47]
0.2[0.09 3.37 3.37]\n", + " [0.1 2.25 2.25][0.09 3.37 3.37]\n", + " [0.07 1.69 1.69][0.09 3.37 3.37]\n", + " [0.06 1.61 1.61][0.09 3.37 3.37]\n", + " [0.06 1.58 1.58][0.09 3.37 3.37]\n", + " [0.06 1.57 1.57][0.09 3.37 3.37]\n", + " [0.05 1.56 1.56][0.09 3.37 3.37]\n", + " [0.05 1.54 1.54][0.09 3.37 3.37]\n", + " [0.05 1.51 1.52][0.09 3.37 3.37]\n", + " [0.05 1.47 1.47][0.09 3.37 3.37]\n", + " [0.06 1.41 1.41][0.09 3.37 3.37]\n", + " [0.09 3.46 3.46]
0.3[0.09 3.37 3.37]\n", + " [0.1 2.25 2.26][0.09 3.37 3.37]\n", + " [0.07 1.66 1.66][0.09 3.37 3.37]\n", + " [0.06 1.58 1.58][0.09 3.37 3.37]\n", + " [0.06 1.56 1.56][0.09 3.37 3.37]\n", + " [0.06 1.54 1.54][0.09 3.37 3.37]\n", + " [0.05 1.52 1.53][0.09 3.37 3.37]\n", + " [0.05 1.5 1.5 ][0.09 3.37 3.37]\n", + " [0.05 1.47 1.47][0.09 3.37 3.37]\n", + " [0.06 1.43 1.43][0.09 3.37 3.37]\n", + " [0.06 1.41 1.41][0.09 3.37 3.37]\n", + " [0.09 3.45 3.45]
0.4[0.09 3.37 3.37]\n", + " [0.1 2.26 2.26][0.09 3.37 3.37]\n", + " [0.07 1.63 1.63][0.09 3.37 3.37]\n", + " [0.06 1.56 1.56][0.09 3.37 3.37]\n", + " [0.06 1.53 1.53][0.09 3.37 3.37]\n", + " [0.06 1.52 1.52][0.09 3.37 3.37]\n", + " [0.05 1.5 1.5 ][0.09 3.37 3.37]\n", + " [0.05 1.47 1.47][0.09 3.37 3.37]\n", + " [0.06 1.44 1.44][0.09 3.37 3.37]\n", + " [0.06 1.41 1.41][0.09 3.37 3.37]\n", + " [0.06 1.45 1.45][0.09 3.37 3.37]\n", + " [0.09 3.44 3.44]
0.5[0.09 3.37 3.37]\n", + " [0.1 2.26 2.27][0.09 3.37 3.37]\n", + " [0.07 1.6 1.6 ][0.09 3.37 3.37]\n", + " [0.06 1.53 1.53][0.09 3.37 3.37]\n", + " [0.06 1.51 1.51][0.09 3.37 3.37]\n", + " [0.06 1.49 1.5 ][0.09 3.37 3.37]\n", + " [0.05 1.47 1.47][0.09 3.37 3.37]\n", + " [0.06 1.45 1.45][0.09 3.37 3.37]\n", + " [0.06 1.42 1.42][0.09 3.37 3.37]\n", + " [0.06 1.4 1.4 ][0.09 3.37 3.37]\n", + " [0.06 1.5 1.5 ][0.09 3.37 3.37]\n", + " [0.09 3.43 3.43]
0.6[0.09 3.37 3.37]\n", + " [0.1 2.27 2.27][0.09 3.37 3.37]\n", + " [0.06 1.56 1.56][0.09 3.37 3.37]\n", + " [0.06 1.51 1.51][0.09 3.37 3.37]\n", + " [0.06 1.49 1.49][0.09 3.37 3.37]\n", + " [0.05 1.47 1.47][0.09 3.37 3.37]\n", + " [0.05 1.45 1.45][0.09 3.37 3.37]\n", + " [0.06 1.43 1.43][0.09 3.37 3.37]\n", + " [0.06 1.4 1.4 ][0.09 3.37 3.37]\n", + " [0.06 1.4 1.4 ][0.09 3.37 3.37]\n", + " [0.07 1.57 1.57][0.09 3.37 3.37]\n", + " [0.09 3.42 3.42]
0.7[0.09 3.37 3.37]\n", + " [0.1 2.28 2.28][0.09 3.37 3.37]\n", + " [0.06 1.53 1.53][0.09 3.37 3.37]\n", + " [0.06 1.49 1.49][0.09 3.37 3.37]\n", + " [0.05 1.47 1.47][0.09 3.37 3.37]\n", + " [0.05 1.46 1.46][0.09 3.37 3.37]\n", + " [0.05 1.43 1.44][0.09 3.37 3.37]\n", + " [0.06 1.41 1.41][0.09 3.37 3.37]\n", + " [0.06 1.39 1.39][0.09 3.37 3.37]\n", + " [0.06 1.41 1.41][0.09 3.37 3.37]\n", + " [0.07 1.63 1.63][0.09 3.37 3.37]\n", + " [0.09 3.41 3.41]
0.8[0.09 3.37 3.37]\n", + " [0.1 2.28 2.29][0.09 3.37 3.37]\n", + " [0.06 1.5 1.5 ][0.09 3.37 3.37]\n", + " [0.05 1.47 1.47][0.09 3.37 3.37]\n", + " [0.05 1.46 1.46][0.09 3.37 3.37]\n", + " [0.05 1.44 1.44][0.09 3.37 3.37]\n", + " [0.05 1.42 1.42][0.09 3.37 3.37]\n", + " [0.06 1.4 1.4 ][0.09 3.37 3.37]\n", + " [0.06 1.39 1.39][0.09 3.37 3.37]\n", + " [0.06 1.42 1.42][0.09 3.37 3.37]\n", + " [0.07 1.69 1.7 ][0.09 3.37 3.37]\n", + " [0.09 3.39 3.4 ]
0.9[0.09 3.37 3.37]\n", + " [0.1 2.29 2.29][0.09 3.37 3.37]\n", + " [0.05 1.47 1.47][0.09 3.37 3.37]\n", + " [0.05 1.46 1.46][0.09 3.37 3.37]\n", + " [0.05 1.44 1.44][0.09 3.37 3.37]\n", + " [0.05 1.43 1.43][0.09 3.37 3.37]\n", + " [0.05 1.41 1.41][0.09 3.37 3.37]\n", + " [0.06 1.39 1.39][0.09 3.37 3.37]\n", + " [0.06 1.38 1.39][0.09 3.37 3.37]\n", + " [0.06 1.44 1.44][0.09 3.37 3.37]\n", + " [0.07 1.76 1.76][0.09 3.37 3.37]\n", + " [0.09 3.38 3.38]
1.0[0.09 3.37 3.37]\n", + " [ nan nan nan][0.09 3.37 3.37]\n", + " [0.05 1.47 1.47][0.09 3.37 3.37]\n", + " [0.05 1.45 1.45][0.09 3.37 3.37]\n", + " [0.05 1.43 1.43][0.09 3.37 3.37]\n", + " [0.05 1.41 1.41][0.09 3.37 3.37]\n", + " [0.05 1.39 1.4 ][0.09 3.37 3.37]\n", + " [0.05 1.38 1.38][0.09 3.37 3.37]\n", + " [0.06 1.39 1.39][0.09 3.37 3.37]\n", + " [0.06 1.47 1.47][0.09 3.37 3.37]\n", + " [0.07 1.81 1.81][0.09 3.37 3.37]\n", + " [0.09 3.37 3.37]
" + ], + "text/plain": [ + "" + ] + }, + "metadata": {}, + "output_type": "display_data" + } + ], + "source": [ + "df_tmp = pd.DataFrame(index=w_list, columns=w_list)\n", + "df_tmp.index.name = 'wwww'\n", + "for (i, wwww), (j, www_) in itertools.product(enumerate(w_list), enumerate(w_list)):\n", + " ix = i*len(w_list) + j\n", + " df_tmp.iloc[i, j] = str(df_out_sweepW_4[ix].iloc[[2,4], [2,3,4]].round(2).values)[1:-1]\n", + "display(df_tmp.style.set_properties(**{'white-space': 'pre-wrap'}))" + ] + }, + { + "cell_type": "code", + "execution_count": 409, + "id": "656090e8-98af-447c-856b-4bc5624a8f0e", + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "$R_0 \\sim N(1.0,1.0^2)$, $R_1 \\sim N(2.0, 1.0^2)$ \n", + " $\\pi_b=[0.8,0.2]$, $\\pi_e=[0.1,0.9]$\n" + ] + }, + { + "data": { + "image/svg+xml": [ + "\n", + "\n", + "\n", + "\n", + " \n", + " \n", + " \n", + " \n", + " 2023-04-25T16:48:08.881151\n", + " image/svg+xml\n", + " \n", + " \n", + " Matplotlib v3.3.4, https://matplotlib.org/\n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " 
\n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " 
\n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + "\n" + ], + "text/plain": [ + "
" + ] + }, + "metadata": { + "needs_background": "light" + }, + "output_type": "display_data" + }, + { + "data": { + "image/svg+xml": [ + "\n", + "\n", + "\n", + "\n", + " \n", + " \n", + " \n", + " \n", + " 2023-04-25T16:48:09.466218\n", + " image/svg+xml\n", + " \n", + " \n", + " Matplotlib v3.3.4, https://matplotlib.org/\n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " 
\n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + "\n" + ], + "text/plain": [ + "
" + ] + }, + "metadata": { + "needs_background": "light" + }, + "output_type": "display_data" + }, + { + "data": { + "image/svg+xml": [ + "\n", + "\n", + "\n", + "\n", + " \n", + " \n", + " \n", + " \n", + " 2023-04-25T16:48:10.351026\n", + " image/svg+xml\n", + " \n", + " \n", + " Matplotlib v3.3.4, https://matplotlib.org/\n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " 
\n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + "\n" + ], + "text/plain": [ + "
" + ] + }, + "metadata": { + "needs_background": "light" + }, + "output_type": "display_data" + } + ], + "source": [ + "variances_grid = [df__.iloc[4,3] for df__ in df_out_sweepW_4]\n", + "\n", + "fig, ax = plt.subplots(figsize=(3,3))\n", + "df_plot = pd.DataFrame(np.array(variances_grid).reshape((11,11)), index=w_list, columns=w_list)\n", + "df_plot.index.name = 'w0'\n", + "df_plot.index = np.round(df_plot.index, 5)\n", + "df_plot.columns.name = 'w1'\n", + "df_plot.columns = np.round(df_plot.columns, 5)\n", + "\n", + "# sns.heatmap(df_plot.T, cmap='crest', square=True, ax=ax,\n", + "# cbar_kws={'label': 'Var', \"shrink\": .82}, \n", + "# vmin=0, vmax=4)\n", + "sns.heatmap(np.log(df_plot.T), cmap='mako_r', square=True, ax=ax,\n", + " cbar_kws=dict(shrink=.82, aspect=40, pad=0.04), \n", + " # vmin=-0.3, vmax=1.3,\n", + " )\n", + "# ax.collections[0].colorbar.set_label('$\\log(\\mathrm{Var})$', labelpad=-9, fontsize=8)\n", + "ax.collections[0].colorbar.set_ticks([0.5,1])\n", + "ax.collections[0].colorbar.set_ticklabels([r'$\\frac{1}{2}$', '$1$'])\n", + "plt.xticks([0.5,10.5], ['$0$','$1$'], rotation=0)\n", + "plt.yticks([0.5,10.5], ['$0$','$1$'])\n", + "ax.yaxis.set_ticks_position('none')\n", + "ax.xaxis.set_ticks_position('none')\n", + "plt.xlabel('$W(\\\\tilde{a}=0|s,a=0)$', labelpad=-9)\n", + "plt.ylabel('$W(\\\\tilde{a}=1|s,a=1)$', labelpad=-9)\n", + "\n", + "for (i,w0), (j,w1) in itertools.product(enumerate(w_list), enumerate(w_list)):\n", + " if np.isnan(df_plot.loc[w0,w1]):\n", + " plt.plot(i+0.5,j+0.5, marker='x', c='gray')\n", + " if np.isclose(df_plot.loc[w0,w1], np.nanmin(df_plot.values)):\n", + " plt.annotate('$*$', (i+0.5,j+0.5), c='k', ha='center', va='center')\n", + " # plt.plot(i+0.5,j+0.5, marker='.', c='yellow', ms=3)\n", + " if w0 == 0.5 and w1 == 0.5:\n", + " plt.plot(i+0.5,j+0.5, marker='o', mfc='none', mec='r', ms=10)\n", + "\n", + "plt.plot(10+0.5,10+0.5, marker='.', c='yellow', ms=8)\n", + "# plt.title('$R_0 \\sim N({},{}^2)$, $R_1 \\sim N({}, 
{}^2)$ \\n $\\pi_b=[{},{}]$, $\\pi_e=[{},{}]$'.format(\n", + "# R[0,0], sigma[0,0], R[0,1], sigma[0,1], π_b[0,0], π_b[0,1], π_e[0,0], π_e[0,1], \n", + "# ))\n", + "plt.gca().invert_yaxis()\n", + "ax.spines['top'].set_visible(True)\n", + "ax.spines['bottom'].set_visible(True)\n", + "ax.spines['right'].set_visible(True)\n", + "ax.spines['left'].set_visible(True)\n", + "\n", + "print('$R_0 \\sim N({},{}^2)$, $R_1 \\sim N({}, {}^2)$ \\n $\\pi_b=[{},{}]$, $\\pi_e=[{},{}]$'.format(\n", + " R[0,0], sigma[0,0], R[0,1], sigma[0,1], π_b[0,0], π_b[0,1], π_e[0,0], π_e[0,1], \n", + "))\n", + "plt.savefig('fig/bandit_logvar_b19_e82_annotlargevar.pdf', bbox_inches='tight')\n", + "plt.show()\n", + "# display(df_plot)\n", + "\n", + "\n", + "biases_grid = [df__.iloc[4,2] for df__ in df_out_sweepW_4]\n", + "\n", + "fig, ax = plt.subplots(figsize=(3,3))\n", + "df_plot = pd.DataFrame(np.array(biases_grid).reshape((11,11)), index=w_list, columns=w_list)\n", + "df_plot.index.name = 'w0'\n", + "df_plot.index = np.round(df_plot.index, 5)\n", + "df_plot.columns.name = 'w1'\n", + "df_plot.columns = np.round(df_plot.columns, 5)\n", + "\n", + "sns.heatmap(df_plot.T, cmap='crest', square=True, cbar_kws={'label': 'Bias', \"shrink\": .82}, vmin=0, vmax=2, ax=ax)\n", + "plt.xticks([0.5,10.5], [0,1], rotation=0)\n", + "plt.yticks([0.5,10.5], [0,1])\n", + "ax.yaxis.set_ticks_position('none')\n", + "ax.xaxis.set_ticks_position('none')\n", + "plt.xlabel('$W(\\\\tilde{a}=0|s,a=0)$', labelpad=-9)\n", + "plt.ylabel('$W(\\\\tilde{a}=1|s,a=1)$', labelpad=-9)\n", + "\n", + "for (i,w0), (j,w1) in itertools.product(enumerate(w_list), enumerate(w_list)):\n", + " if np.isnan(df_plot.loc[w0,w1]):\n", + " plt.plot(i+0.5,j+0.5, marker='x', c='r')\n", + "\n", + "# plt.title('$R_0 \\sim N({},{}^2)$, $R_1 \\sim N({}, {}^2)$ \\n $\\pi_b=[{},{}]$, $\\pi_e=[{},{}]$'.format(\n", + "# R[0,0], sigma[0,0], R[0,1], sigma[0,1], π_b[0,0], π_b[0,1], π_e[0,0], π_e[0,1], \n", + "# ))\n", + "plt.gca().invert_yaxis()\n", + 
"ax.spines['top'].set_visible(True)\n", + "ax.spines['bottom'].set_visible(True)\n", + "ax.spines['right'].set_visible(True)\n", + "ax.spines['left'].set_visible(True)\n", + "\n", + "plt.show()\n", + "\n", + "\n", + "mses_grid = [df__.iloc[4,4] for df__ in df_out_sweepW_4]\n", + "\n", + "fig, ax = plt.subplots(figsize=(3,3))\n", + "df_plot = pd.DataFrame(np.array(mses_grid).reshape((11,11)), index=w_list, columns=w_list)\n", + "df_plot.index.name = 'w0'\n", + "df_plot.index = np.round(df_plot.index, 5)\n", + "df_plot.columns.name = 'w1'\n", + "df_plot.columns = np.round(df_plot.columns, 5)\n", + "\n", + "sns.heatmap(df_plot.T, cmap='crest', square=True, cbar_kws={'label': 'MSE', \"shrink\": .82}, vmin=0, vmax=4, ax=ax)\n", + "plt.xticks([0.5,10.5], [0,1], rotation=0)\n", + "plt.yticks([0.5,10.5], [0,1])\n", + "ax.yaxis.set_ticks_position('none')\n", + "ax.xaxis.set_ticks_position('none')\n", + "plt.xlabel('$W(\\\\tilde{a}=0|s,a=0)$', labelpad=-9)\n", + "plt.ylabel('$W(\\\\tilde{a}=1|s,a=1)$', labelpad=-9)\n", + "\n", + "for (i,w0), (j,w1) in itertools.product(enumerate(w_list), enumerate(w_list)):\n", + " if np.isnan(df_plot.loc[w0,w1]):\n", + " plt.plot(i+0.5,j+0.5, marker='x', c='r')\n", + "\n", + "# plt.title('$R_0 \\sim N({},{}^2)$, $R_1 \\sim N({}, {}^2)$ \\n $\\pi_b=[{},{}]$, $\\pi_e=[{},{}]$'.format(\n", + "# R[0,0], sigma[0,0], R[0,1], sigma[0,1], π_b[0,0], π_b[0,1], π_e[0,0], π_e[0,1], \n", + "# ))\n", + "plt.gca().invert_yaxis()\n", + "ax.spines['top'].set_visible(True)\n", + "ax.spines['bottom'].set_visible(True)\n", + "ax.spines['right'].set_visible(True)\n", + "ax.spines['left'].set_visible(True)\n", + "\n", + "plt.show()" + ] + }, + { + "cell_type": "markdown", + "id": "9462d745-23c7-40e9-a8f7-b18c6a4b4fc5", + "metadata": {}, + "source": [ + "## >>Annot has smaller variance" + ] + }, + { + "cell_type": "code", + "execution_count": 84, + "id": "cbde037e-0653-4992-8206-ffcefdd6ae8a", + "metadata": {}, + "outputs": [], + "source": [ + "R, sigma = 
np.array([[1, 2.],]), np.array([[1, 1],])" + ] + }, + { + "cell_type": "code", + "execution_count": 101, + "id": "297465ff-34c2-4ec1-bc37-34fddaf6d1d6", + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "[[0.1 0.9]] [[0.8 0.2]]\n" + ] + } + ], + "source": [ + "π_b = πs[3]\n", + "π_e = πs[4]\n", + "print(π_b, π_e)" + ] + }, + { + "cell_type": "code", + "execution_count": 86, + "id": "576d3bed-4194-4b6c-b28c-1d61d477660c", + "metadata": {}, + "outputs": [ + { + "name": "stderr", + "output_type": "stream", + "text": [ + ":77: RuntimeWarning: invalid value encountered in true_divide\n", + " (np.sum(w* π_e[x,a] / π_b_[x,a] * r) + np.nansum(wc* π_e[x,1-a] / π_b_[x,1-a] * rc)) / (N)\n", + ":80: RuntimeWarning: invalid value encountered in true_divide\n", + " (np.sum(w* π_e[x,a] / π_b_[x,a] * r) + np.nansum(wc* π_e[x,1-a] / π_b_[x,1-a] * rc)) / (np.sum(w* π_e[x,a] / π_b_[x,a]) + np.sum((wc* π_e[x,1-a] / π_b_[x,1-a])[c==1])),\n" + ] + } + ], + "source": [ + "df_out_sweepW_6 = []\n", + "for wwww, www_ in itertools.product(w_list, w_list):\n", + " ww = np.array([[\n", + " [wwww, 1-wwww],\n", + " [1-www_, www_],\n", + " ]])\n", + " df_out = single_run(annot_std_scale=0.5)\n", + " df_out_sweepW_6.append(df_out)" + ] + }, + { + "cell_type": "code", + "execution_count": 87, + "id": "536055a2-3a57-4484-96cd-b8372c6fa640", + "metadata": { + "tags": [] + }, + "outputs": [ + { + "data": { + "text/html": [ + "\n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + 
" \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + "
0.0 0.1 0.2 0.3 0.4 0.5 0.6 0.7 0.8 0.9 1.0
wwww
0.0[0.09 3.37 3.37]\n", + " [0.04 1.1 1.1 ][0.09 3.37 3.37]\n", + " [0.03 0.54 0.54][0.09 3.37 3.37]\n", + " [0.02 0.44 0.44][0.09 3.37 3.37]\n", + " [0.02 0.44 0.44][0.09 3.37 3.37]\n", + " [0.02 0.45 0.45][0.09 3.37 3.37]\n", + " [0.02 0.47 0.47][0.09 3.37 3.37]\n", + " [0.02 0.48 0.48][0.09 3.37 3.37]\n", + " [0.02 0.5 0.5 ][0.09 3.37 3.37]\n", + " [0.02 0.51 0.51][0.09 3.37 3.37]\n", + " [0.02 0.52 0.52][0.09 3.37 3.37]\n", + " [ nan nan nan]
0.1[0.09 3.37 3.37]\n", + " [0.04 1.13 1.13][0.09 3.37 3.37]\n", + " [0.03 0.54 0.54][0.09 3.37 3.37]\n", + " [0.02 0.44 0.44][0.09 3.37 3.37]\n", + " [0.02 0.43 0.43][0.09 3.37 3.37]\n", + " [0.02 0.44 0.44][0.09 3.37 3.37]\n", + " [0.02 0.45 0.45][0.09 3.37 3.37]\n", + " [0.02 0.46 0.46][0.09 3.37 3.37]\n", + " [0.02 0.46 0.46][0.09 3.37 3.37]\n", + " [0.02 0.46 0.46][0.09 3.37 3.37]\n", + " [0.03 0.47 0.47][0.09 3.37 3.37]\n", + " [0.09 3.46 3.46]
0.2[0.09 3.37 3.37]\n", + " [0.04 1.15 1.15][0.09 3.37 3.37]\n", + " [0.03 0.53 0.53][0.09 3.37 3.37]\n", + " [0.02 0.44 0.44][0.09 3.37 3.37]\n", + " [0.02 0.43 0.43][0.09 3.37 3.37]\n", + " [0.02 0.44 0.44][0.09 3.37 3.37]\n", + " [0.02 0.45 0.45][0.09 3.37 3.37]\n", + " [0.02 0.45 0.45][0.09 3.37 3.37]\n", + " [0.02 0.45 0.45][0.09 3.37 3.37]\n", + " [0.03 0.47 0.47][0.09 3.37 3.37]\n", + " [0.03 0.61 0.62][0.09 3.37 3.37]\n", + " [0.09 3.45 3.45]
0.3[0.09 3.37 3.37]\n", + " [0.04 1.18 1.18][0.09 3.37 3.37]\n", + " [0.03 0.52 0.52][0.09 3.37 3.37]\n", + " [0.03 0.44 0.44][0.09 3.37 3.37]\n", + " [0.02 0.44 0.44][0.09 3.37 3.37]\n", + " [0.02 0.44 0.44][0.09 3.37 3.37]\n", + " [0.02 0.45 0.45][0.09 3.37 3.37]\n", + " [0.02 0.46 0.46][0.09 3.37 3.37]\n", + " [0.03 0.47 0.47][0.09 3.37 3.37]\n", + " [0.03 0.53 0.53][0.09 3.37 3.37]\n", + " [0.04 0.8 0.8 ][0.09 3.37 3.37]\n", + " [0.09 3.44 3.44]
0.4[0.09 3.37 3.37]\n", + " [0.04 1.21 1.21][0.09 3.37 3.37]\n", + " [0.03 0.51 0.51][0.09 3.37 3.37]\n", + " [0.03 0.44 0.44][0.09 3.37 3.37]\n", + " [0.02 0.44 0.44][0.09 3.37 3.37]\n", + " [0.02 0.45 0.45][0.09 3.37 3.37]\n", + " [0.02 0.46 0.46][0.09 3.37 3.37]\n", + " [0.03 0.47 0.47][0.09 3.37 3.37]\n", + " [0.03 0.51 0.51][0.09 3.37 3.37]\n", + " [0.03 0.61 0.61][0.09 3.37 3.37]\n", + " [0.04 0.97 0.97][0.09 3.37 3.37]\n", + " [0.09 3.43 3.44]
0.5[0.09 3.37 3.37]\n", + " [0.04 1.23 1.23][0.09 3.37 3.37]\n", + " [0.03 0.5 0.5 ][0.09 3.37 3.37]\n", + " [0.03 0.45 0.45][0.09 3.37 3.37]\n", + " [0.02 0.45 0.45][0.09 3.37 3.37]\n", + " [0.02 0.46 0.46][0.09 3.37 3.37]\n", + " [0.03 0.47 0.47][0.09 3.37 3.37]\n", + " [0.03 0.5 0.5 ][0.09 3.37 3.37]\n", + " [0.03 0.55 0.55][0.09 3.37 3.37]\n", + " [0.03 0.69 0.7 ][0.09 3.37 3.37]\n", + " [0.04 1.12 1.13][0.09 3.37 3.37]\n", + " [0.09 3.42 3.43]
0.6[0.09 3.37 3.37]\n", + " [0.04 1.26 1.26][0.09 3.37 3.37]\n", + " [0.03 0.49 0.49][0.09 3.37 3.37]\n", + " [0.03 0.45 0.45][0.09 3.37 3.37]\n", + " [0.03 0.46 0.46][0.09 3.37 3.37]\n", + " [0.03 0.47 0.47][0.09 3.37 3.37]\n", + " [0.03 0.49 0.49][0.09 3.37 3.37]\n", + " [0.03 0.53 0.53][0.09 3.37 3.37]\n", + " [0.03 0.6 0.6 ][0.09 3.37 3.37]\n", + " [0.04 0.78 0.78][0.09 3.37 3.37]\n", + " [0.05 1.26 1.26][0.09 3.37 3.37]\n", + " [0.09 3.41 3.42]
0.7[0.09 3.37 3.37]\n", + " [0.04 1.29 1.29][0.09 3.37 3.37]\n", + " [0.03 0.48 0.48][0.09 3.37 3.37]\n", + " [0.03 0.46 0.46][0.09 3.37 3.37]\n", + " [0.03 0.47 0.47][0.09 3.37 3.37]\n", + " [0.03 0.49 0.49][0.09 3.37 3.37]\n", + " [0.03 0.51 0.51][0.09 3.37 3.37]\n", + " [0.03 0.56 0.56][0.09 3.37 3.37]\n", + " [0.03 0.65 0.66][0.09 3.37 3.37]\n", + " [0.04 0.86 0.86][0.09 3.37 3.37]\n", + " [0.05 1.38 1.38][0.09 3.37 3.37]\n", + " [0.09 3.4 3.4 ]
0.8[0.09 3.37 3.37]\n", + " [0.04 1.31 1.31][0.09 3.37 3.37]\n", + " [0.03 0.47 0.47][0.09 3.37 3.37]\n", + " [0.03 0.47 0.47][0.09 3.37 3.37]\n", + " [0.03 0.49 0.49][0.09 3.37 3.37]\n", + " [0.03 0.51 0.51][0.09 3.37 3.37]\n", + " [0.03 0.54 0.54][0.09 3.37 3.37]\n", + " [0.03 0.59 0.6 ][0.09 3.37 3.37]\n", + " [0.03 0.71 0.71][0.09 3.37 3.37]\n", + " [0.04 0.94 0.94][0.09 3.37 3.37]\n", + " [0.05 1.49 1.49][0.09 3.37 3.37]\n", + " [0.09 3.39 3.39]
0.9[0.09 3.37 3.37]\n", + " [0.04 1.34 1.34][0.09 3.37 3.37]\n", + " [0.03 0.47 0.47][0.09 3.37 3.37]\n", + " [0.03 0.49 0.49][0.09 3.37 3.37]\n", + " [0.03 0.5 0.5 ][0.09 3.37 3.37]\n", + " [0.03 0.52 0.52][0.09 3.37 3.37]\n", + " [0.03 0.56 0.56][0.09 3.37 3.37]\n", + " [0.03 0.63 0.63][0.09 3.37 3.37]\n", + " [0.03 0.76 0.76][0.09 3.37 3.37]\n", + " [0.04 1.01 1.01][0.09 3.37 3.37]\n", + " [0.05 1.58 1.58][0.09 3.37 3.37]\n", + " [0.09 3.38 3.38]
1.0[0.09 3.37 3.37]\n", + " [ nan nan nan][0.09 3.37 3.37]\n", + " [0.02 0.5 0.5 ][0.09 3.37 3.37]\n", + " [0.03 0.51 0.51][0.09 3.37 3.37]\n", + " [0.03 0.52 0.52][0.09 3.37 3.37]\n", + " [0.03 0.54 0.54][0.09 3.37 3.37]\n", + " [0.03 0.59 0.59][0.09 3.37 3.37]\n", + " [0.03 0.67 0.67][0.09 3.37 3.37]\n", + " [0.04 0.81 0.81][0.09 3.37 3.37]\n", + " [0.04 1.08 1.08][0.09 3.37 3.37]\n", + " [0.05 1.66 1.66][0.09 3.37 3.37]\n", + " [0.09 3.37 3.37]
" + ], + "text/plain": [ + "" + ] + }, + "metadata": {}, + "output_type": "display_data" + } + ], + "source": [ + "df_tmp = pd.DataFrame(index=w_list, columns=w_list)\n", + "df_tmp.index.name = 'wwww'\n", + "for (i, wwww), (j, www_) in itertools.product(enumerate(w_list), enumerate(w_list)):\n", + " ix = i*len(w_list) + j\n", + " df_tmp.iloc[i, j] = str(df_out_sweepW_6[ix].iloc[[2,4], [2,3,4]].round(2).values)[1:-1]\n", + "display(df_tmp.style.set_properties(**{'white-space': 'pre-wrap'}))" + ] + }, + { + "cell_type": "code", + "execution_count": 410, + "id": "22f3c751-0da8-4bc6-8b46-0a04ae37d584", + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "$R_0 \\sim N(1.0,1.0^2)$, $R_1 \\sim N(2.0, 1.0^2)$ \n", + " $\\pi_b=[0.8,0.2]$, $\\pi_e=[0.1,0.9]$\n" + ] + }, + { + "data": { + "image/svg+xml": [ + "\n", + "\n", + "\n", + "\n", + " \n", + " \n", + " \n", + " \n", + " 2023-04-25T16:48:18.109575\n", + " image/svg+xml\n", + " \n", + " \n", + " Matplotlib v3.3.4, https://matplotlib.org/\n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " 
\n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " 
\n", + " \n", + " \n", + " \n", + " \n", + " \n", + "\n" + ], + "text/plain": [ + "
" + ] + }, + "metadata": { + "needs_background": "light" + }, + "output_type": "display_data" + }, + { + "data": { + "image/svg+xml": [ + "\n", + "\n", + "\n", + "\n", + " \n", + " \n", + " \n", + " \n", + " 2023-04-25T16:48:18.805007\n", + " image/svg+xml\n", + " \n", + " \n", + " Matplotlib v3.3.4, https://matplotlib.org/\n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " 
\n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + "\n" + ], + "text/plain": [ + "
" + ] + }, + "metadata": { + "needs_background": "light" + }, + "output_type": "display_data" + }, + { + "data": { + "image/svg+xml": [ + "\n", + "\n", + "\n", + "\n", + " \n", + " \n", + " \n", + " \n", + " 2023-04-25T16:48:19.375568\n", + " image/svg+xml\n", + " \n", + " \n", + " Matplotlib v3.3.4, https://matplotlib.org/\n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " 
\n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + "\n" + ], + "text/plain": [ + "
" + ] + }, + "metadata": { + "needs_background": "light" + }, + "output_type": "display_data" + } + ], + "source": [ + "variances_grid = [df__.iloc[4,3] for df__ in df_out_sweepW_6]\n", + "\n", + "fig, ax = plt.subplots(figsize=(3,3))\n", + "df_plot = pd.DataFrame(np.array(variances_grid).reshape((11,11)), index=w_list, columns=w_list)\n", + "df_plot.index.name = 'w0'\n", + "df_plot.index = np.round(df_plot.index, 5)\n", + "df_plot.columns.name = 'w1'\n", + "df_plot.columns = np.round(df_plot.columns, 5)\n", + "\n", + "sns.heatmap(np.log(df_plot.T), cmap='mako_r', square=True, ax=ax,\n", + " cbar_kws=dict(shrink=.82, aspect=40, pad=0.04), \n", + " # vmin=-1, vmax=1,\n", + " )\n", + "# ax.collections[0].colorbar.set_label('$\\log(\\mathrm{Var})$', labelpad=-9, fontsize=8, ha='left')\n", + "ax.collections[0].colorbar.set_ticks([0,1])\n", + "# ax.collections[0].colorbar.set_ticklabels([r'$\\frac{1}{2}$', '$1$'])\n", + "plt.xticks([0.5,10.5], ['$0$','$1$'], rotation=0)\n", + "plt.yticks([0.5,10.5], ['$0$','$1$'])\n", + "ax.yaxis.set_ticks_position('none')\n", + "ax.xaxis.set_ticks_position('none')\n", + "plt.xlabel('$W(\\\\tilde{a}=0|s,a=0)$', labelpad=-9)\n", + "plt.ylabel('$W(\\\\tilde{a}=1|s,a=1)$', labelpad=-9)\n", + "\n", + "for (i,w0), (j,w1) in itertools.product(enumerate(w_list), enumerate(w_list)):\n", + " if np.isnan(df_plot.loc[w0,w1]):\n", + " plt.plot(i+0.5,j+0.5, marker='x', c='gray')\n", + " if np.isclose(df_plot.loc[w0,w1], np.nanmin(df_plot.values)):\n", + " plt.annotate('$*$', (i+0.5,j+0.5), c='k', ha='center', va='center')\n", + " # plt.plot(i+0.5,j+0.5, marker='.', c='yellow', ms=3)\n", + " if w0 == 0.5 and w1 == 0.5:\n", + " plt.plot(i+0.5,j+0.5, marker='o', mfc='none', mec='r', ms=10)\n", + "\n", + "plt.plot(10+0.5,10+0.5, marker='.', c='yellow', ms=8)\n", + "# plt.title('$R_0 \\sim N({},{}^2)$, $R_1 \\sim N({}, {}^2)$ \\n $\\pi_b=[{},{}]$, $\\pi_e=[{},{}]$'.format(\n", + "# R[0,0], sigma[0,0], R[0,1], sigma[0,1], π_b[0,0], π_b[0,1], π_e[0,0], 
π_e[0,1], \n", + "# ))\n", + "plt.gca().invert_yaxis()\n", + "ax.spines['top'].set_visible(True)\n", + "ax.spines['bottom'].set_visible(True)\n", + "ax.spines['right'].set_visible(True)\n", + "ax.spines['left'].set_visible(True)\n", + "\n", + "print('$R_0 \\sim N({},{}^2)$, $R_1 \\sim N({}, {}^2)$ \\n $\\pi_b=[{},{}]$, $\\pi_e=[{},{}]$'.format(\n", + " R[0,0], sigma[0,0], R[0,1], sigma[0,1], π_b[0,0], π_b[0,1], π_e[0,0], π_e[0,1], \n", + "))\n", + "plt.savefig('fig/bandit_logvar_b19_e82_annotsmallvar.pdf', bbox_inches='tight')\n", + "plt.show()\n", + "# display(df_plot)\n", + "\n", + "\n", + "biases_grid = [df__.iloc[4,2] for df__ in df_out_sweepW_6]\n", + "\n", + "fig, ax = plt.subplots(figsize=(3,3))\n", + "df_plot = pd.DataFrame(np.array(biases_grid).reshape((11,11)), index=w_list, columns=w_list)\n", + "df_plot.index.name = 'w0'\n", + "df_plot.index = np.round(df_plot.index, 5)\n", + "df_plot.columns.name = 'w1'\n", + "df_plot.columns = np.round(df_plot.columns, 5)\n", + "\n", + "sns.heatmap(df_plot.T, cmap='crest', square=True, cbar_kws={'label': 'Bias', \"shrink\": .82}, vmin=0, vmax=2, ax=ax)\n", + "plt.xticks([0.5,10.5], [0,1], rotation=0)\n", + "plt.yticks([0.5,10.5], [0,1])\n", + "ax.yaxis.set_ticks_position('none')\n", + "ax.xaxis.set_ticks_position('none')\n", + "plt.xlabel('$W(\\\\tilde{a}=0|s,a=0)$', labelpad=-9)\n", + "plt.ylabel('$W(\\\\tilde{a}=1|s,a=1)$', labelpad=-9)\n", + "\n", + "for (i,w0), (j,w1) in itertools.product(enumerate(w_list), enumerate(w_list)):\n", + " if np.isnan(df_plot.loc[w0,w1]):\n", + " plt.plot(i+0.5,j+0.5, marker='x', c='r')\n", + "\n", + "# plt.title('$R_0 \\sim N({},{}^2)$, $R_1 \\sim N({}, {}^2)$ \\n $\\pi_b=[{},{}]$, $\\pi_e=[{},{}]$'.format(\n", + "# R[0,0], sigma[0,0], R[0,1], sigma[0,1], π_b[0,0], π_b[0,1], π_e[0,0], π_e[0,1], \n", + "# ))\n", + "plt.gca().invert_yaxis()\n", + "ax.spines['top'].set_visible(True)\n", + "ax.spines['bottom'].set_visible(True)\n", + "ax.spines['right'].set_visible(True)\n", + 
"ax.spines['left'].set_visible(True)\n", + "plt.show()\n", + "\n", + "\n", + "mses_grid = [df__.iloc[4,4] for df__ in df_out_sweepW_6]\n", + "\n", + "fig, ax = plt.subplots(figsize=(3,3))\n", + "df_plot = pd.DataFrame(np.array(mses_grid).reshape((11,11)), index=w_list, columns=w_list)\n", + "df_plot.index.name = 'w0'\n", + "df_plot.index = np.round(df_plot.index, 5)\n", + "df_plot.columns.name = 'w1'\n", + "df_plot.columns = np.round(df_plot.columns, 5)\n", + "\n", + "sns.heatmap(df_plot.T, cmap='crest', square=True, cbar_kws={'label': 'MSE', \"shrink\": .82}, vmin=0, vmax=4, ax=ax)\n", + "plt.xticks([0.5,10.5], [0,1], rotation=0)\n", + "plt.yticks([0.5,10.5], [0,1])\n", + "ax.yaxis.set_ticks_position('none')\n", + "ax.xaxis.set_ticks_position('none')\n", + "plt.xlabel('$W(\\\\tilde{a}=0|s,a=0)$', labelpad=-9)\n", + "plt.ylabel('$W(\\\\tilde{a}=1|s,a=1)$', labelpad=-9)\n", + "\n", + "for (i,w0), (j,w1) in itertools.product(enumerate(w_list), enumerate(w_list)):\n", + " if np.isnan(df_plot.loc[w0,w1]):\n", + " plt.plot(i+0.5,j+0.5, marker='x', c='r')\n", + "\n", + "# plt.title('$R_0 \\sim N({},{}^2)$, $R_1 \\sim N({}, {}^2)$ \\n $\\pi_b=[{},{}]$, $\\pi_e=[{},{}]$'.format(\n", + "# R[0,0], sigma[0,0], R[0,1], sigma[0,1], π_b[0,0], π_b[0,1], π_e[0,0], π_e[0,1], \n", + "# ))\n", + "plt.gca().invert_yaxis()\n", + "ax.spines['top'].set_visible(True)\n", + "ax.spines['bottom'].set_visible(True)\n", + "ax.spines['right'].set_visible(True)\n", + "ax.spines['left'].set_visible(True)\n", + "\n", + "plt.show()" + ] + }, + { + "cell_type": "markdown", + "id": "e12bdb40-5cd6-48c8-9861-2401a3d58369", + "metadata": {}, + "source": [ + "## >>Annot not useful" + ] + }, + { + "cell_type": "code", + "execution_count": 128, + "id": "44d68e22-2dce-4a6a-9801-678dd70c08fc", + "metadata": {}, + "outputs": [], + "source": [ + "R, sigma = np.array([[1, 2.],]), np.array([[1, 1],])" + ] + }, + { + "cell_type": "code", + "execution_count": 129, + "id": "2541b170-39ba-4a72-b7c9-74bfc2480790", + 
"metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "[[0.9 0.1]] [[0.95 0.05]]\n" + ] + } + ], + "source": [ + "π_b = np.array([[0.9, 0.1]])\n", + "π_e = np.array([[0.95, 0.05]])\n", + "print(π_b, π_e)" + ] + }, + { + "cell_type": "code", + "execution_count": 130, + "id": "fc5abcaa-01b0-47c7-ae0f-f74e0fe30eb8", + "metadata": {}, + "outputs": [ + { + "name": "stderr", + "output_type": "stream", + "text": [ + ":77: RuntimeWarning: invalid value encountered in true_divide\n", + " (np.sum(w* π_e[x,a] / π_b_[x,a] * r) + np.nansum(wc* π_e[x,1-a] / π_b_[x,1-a] * rc)) / (N)\n", + ":80: RuntimeWarning: invalid value encountered in true_divide\n", + " (np.sum(w* π_e[x,a] / π_b_[x,a] * r) + np.nansum(wc* π_e[x,1-a] / π_b_[x,1-a] * rc)) / (np.sum(w* π_e[x,a] / π_b_[x,a]) + np.sum((wc* π_e[x,1-a] / π_b_[x,1-a])[c==1])),\n" + ] + } + ], + "source": [ + "df_out_sweepW_7 = []\n", + "for wwww, www_ in itertools.product(w_list, w_list):\n", + " ww = np.array([[\n", + " [wwww, 1-wwww],\n", + " [1-www_, www_],\n", + " ]])\n", + " df_out = single_run()\n", + " df_out_sweepW_7.append(df_out)" + ] + }, + { + "cell_type": "code", + "execution_count": 131, + "id": "dfb614ae-50a7-41b6-b0aa-dcd10d76fe19", + "metadata": { + "tags": [] + }, + "outputs": [ + { + "data": { + "text/html": [ + "\n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + 
" \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + "
0.0 0.1 0.2 0.3 0.4 0.5 0.6 0.7 0.8 0.9 1.0
wwww
0.0[ 0.07 1.02 1.02]\n", + " [-0.06 3.59 3.59][ 0.07 1.02 1.02]\n", + " [-0.06 3.59 3.59][ 0.07 1.02 1.02]\n", + " [-0.06 3.59 3.59][ 0.07 1.02 1.02]\n", + " [-0.06 3.6 3.6 ][ 0.07 1.02 1.02]\n", + " [-0.06 3.6 3.6 ][ 0.07 1.02 1.02]\n", + " [-0.06 3.6 3.6 ][ 0.07 1.02 1.02]\n", + " [-0.06 3.61 3.61][ 0.07 1.02 1.02]\n", + " [-0.06 3.61 3.61][ 0.07 1.02 1.02]\n", + " [-0.06 3.61 3.61][ 0.07 1.02 1.02]\n", + " [-0.06 3.61 3.61][0.07 1.02 1.02]\n", + " [ nan nan nan]
0.1[0.07 1.02 1.02]\n", + " [0. 1.83 1.83][0.07 1.02 1.02]\n", + " [0.01 1.74 1.74][0.07 1.02 1.02]\n", + " [0.01 1.64 1.64][0.07 1.02 1.02]\n", + " [0.01 1.54 1.54][0.07 1.02 1.02]\n", + " [0.02 1.42 1.42][0.07 1.02 1.02]\n", + " [0.02 1.3 1.3 ][0.07 1.02 1.02]\n", + " [0.03 1.17 1.17][0.07 1.02 1.02]\n", + " [0.04 1.05 1.05][0.07 1.02 1.02]\n", + " [0.05 0.95 0.95][0.07 1.02 1.02]\n", + " [0.06 0.93 0.94][0.07 1.02 1.02]\n", + " [0.07 1.05 1.06]
0.2[0.07 1.02 1.02]\n", + " [0.03 1.29 1.29][0.07 1.02 1.02]\n", + " [0.03 1.23 1.23][0.07 1.02 1.02]\n", + " [0.03 1.16 1.17][0.07 1.02 1.02]\n", + " [0.04 1.1 1.1 ][0.07 1.02 1.02]\n", + " [0.04 1.05 1.05][0.07 1.02 1.02]\n", + " [0.04 0.99 0.99][0.07 1.02 1.02]\n", + " [0.05 0.95 0.95][0.07 1.02 1.02]\n", + " [0.05 0.93 0.93][0.07 1.02 1.02]\n", + " [0.06 0.93 0.94][0.07 1.02 1.02]\n", + " [0.06 0.97 0.97][0.07 1.02 1.02]\n", + " [0.07 1.05 1.06]
0.3[0.07 1.02 1.02]\n", + " [0.04 1.08 1.08][0.07 1.02 1.02]\n", + " [0.04 1.04 1.04][0.07 1.02 1.02]\n", + " [0.04 1.01 1.01][0.07 1.02 1.02]\n", + " [0.04 0.98 0.98][0.07 1.02 1.02]\n", + " [0.05 0.95 0.95][0.07 1.02 1.02]\n", + " [0.05 0.93 0.94][0.07 1.02 1.02]\n", + " [0.05 0.93 0.93][0.07 1.02 1.02]\n", + " [0.06 0.93 0.94][0.07 1.02 1.02]\n", + " [0.06 0.95 0.96][0.07 1.02 1.02]\n", + " [0.07 0.99 1. ][0.07 1.02 1.02]\n", + " [0.07 1.05 1.05]
0.4[0.07 1.02 1.02]\n", + " [0.04 0.99 0.99][0.07 1.02 1.02]\n", + " [0.05 0.97 0.97][0.07 1.02 1.02]\n", + " [0.05 0.95 0.95][0.07 1.02 1.02]\n", + " [0.05 0.94 0.94][0.07 1.02 1.02]\n", + " [0.05 0.93 0.93][0.07 1.02 1.02]\n", + " [0.06 0.93 0.93][0.07 1.02 1.02]\n", + " [0.06 0.93 0.94][0.07 1.02 1.02]\n", + " [0.06 0.95 0.95][0.07 1.02 1.02]\n", + " [0.06 0.97 0.97][0.07 1.02 1.02]\n", + " [0.07 1.01 1.01][0.07 1.02 1.02]\n", + " [0.07 1.05 1.05]
0.5[0.07 1.02 1.02]\n", + " [0.05 0.95 0.95][0.07 1.02 1.02]\n", + " [0.05 0.94 0.94][0.07 1.02 1.02]\n", + " [0.05 0.93 0.93][0.07 1.02 1.02]\n", + " [0.05 0.93 0.93][0.07 1.02 1.02]\n", + " [0.06 0.93 0.93][0.07 1.02 1.02]\n", + " [0.06 0.93 0.94][0.07 1.02 1.02]\n", + " [0.06 0.94 0.95][0.07 1.02 1.02]\n", + " [0.06 0.96 0.96][0.07 1.02 1.02]\n", + " [0.07 0.98 0.99][0.07 1.02 1.02]\n", + " [0.07 1.01 1.01][0.07 1.02 1.02]\n", + " [0.07 1.05 1.05]
0.6[0.07 1.02 1.02]\n", + " [0.05 0.93 0.93][0.07 1.02 1.02]\n", + " [0.05 0.93 0.93][0.07 1.02 1.02]\n", + " [0.06 0.93 0.93][0.07 1.02 1.02]\n", + " [0.06 0.93 0.93][0.07 1.02 1.02]\n", + " [0.06 0.93 0.94][0.07 1.02 1.02]\n", + " [0.06 0.94 0.94][0.07 1.02 1.02]\n", + " [0.06 0.95 0.96][0.07 1.02 1.02]\n", + " [0.06 0.97 0.97][0.07 1.02 1.02]\n", + " [0.07 0.99 0.99][0.07 1.02 1.02]\n", + " [0.07 1.01 1.02][0.07 1.02 1.02]\n", + " [0.07 1.04 1.05]
0.7[0.07 1.02 1.02]\n", + " [0.05 0.93 0.93][0.07 1.02 1.02]\n", + " [0.06 0.93 0.93][0.07 1.02 1.02]\n", + " [0.06 0.93 0.93][0.07 1.02 1.02]\n", + " [0.06 0.93 0.94][0.07 1.02 1.02]\n", + " [0.06 0.94 0.94][0.07 1.02 1.02]\n", + " [0.06 0.95 0.95][0.07 1.02 1.02]\n", + " [0.06 0.96 0.96][0.07 1.02 1.02]\n", + " [0.07 0.98 0.98][0.07 1.02 1.02]\n", + " [0.07 0.99 1. ][0.07 1.02 1.02]\n", + " [0.07 1.02 1.02][0.07 1.02 1.02]\n", + " [0.07 1.04 1.04]
0.8[0.07 1.02 1.02]\n", + " [0.06 0.93 0.93][0.07 1.02 1.02]\n", + " [0.06 0.93 0.93][0.07 1.02 1.02]\n", + " [0.06 0.93 0.94][0.07 1.02 1.02]\n", + " [0.06 0.94 0.94][0.07 1.02 1.02]\n", + " [0.06 0.95 0.95][0.07 1.02 1.02]\n", + " [0.06 0.96 0.96][0.07 1.02 1.02]\n", + " [0.06 0.97 0.97][0.07 1.02 1.02]\n", + " [0.07 0.98 0.98][0.07 1.02 1.02]\n", + " [0.07 1. 1. ][0.07 1.02 1.02]\n", + " [0.07 1.01 1.02][0.07 1.02 1.02]\n", + " [0.07 1.03 1.04]
0.9[0.07 1.02 1.02]\n", + " [0.06 0.93 0.93][0.07 1.02 1.02]\n", + " [0.06 0.93 0.94][0.07 1.02 1.02]\n", + " [0.06 0.94 0.94][0.07 1.02 1.02]\n", + " [0.06 0.94 0.95][0.07 1.02 1.02]\n", + " [0.06 0.95 0.95][0.07 1.02 1.02]\n", + " [0.06 0.96 0.96][0.07 1.02 1.02]\n", + " [0.06 0.97 0.97][0.07 1.02 1.02]\n", + " [0.06 0.98 0.98][0.07 1.02 1.02]\n", + " [0.07 0.99 0.99][0.07 1.02 1.02]\n", + " [0.07 1.01 1.01][0.07 1.02 1.02]\n", + " [0.07 1.02 1.03]
1.0[0.07 1.02 1.02]\n", + " [ nan nan nan][0.07 1.02 1.02]\n", + " [0.05 0.99 0.99][0.07 1.02 1.02]\n", + " [0.05 0.99 0.99][0.07 1.02 1.02]\n", + " [0.06 0.98 0.99][0.07 1.02 1.02]\n", + " [0.06 0.98 0.99][0.07 1.02 1.02]\n", + " [0.06 0.98 0.99][0.07 1.02 1.02]\n", + " [0.06 0.99 0.99][0.07 1.02 1.02]\n", + " [0.06 0.99 0.99][0.07 1.02 1.02]\n", + " [0.06 1. 1. ][0.07 1.02 1.02]\n", + " [0.06 1.01 1.01][0.07 1.02 1.02]\n", + " [0.07 1.02 1.02]
" + ], + "text/plain": [ + "" + ] + }, + "metadata": {}, + "output_type": "display_data" + } + ], + "source": [ + "df_tmp = pd.DataFrame(index=w_list, columns=w_list)\n", + "df_tmp.index.name = 'wwww'\n", + "for (i, wwww), (j, www_) in itertools.product(enumerate(w_list), enumerate(w_list)):\n", + " ix = i*len(w_list) + j\n", + " df_tmp.iloc[i, j] = str(df_out_sweepW_7[ix].iloc[[2,4], [2,3,4]].round(2).values)[1:-1]\n", + "display(df_tmp.style.set_properties(**{'white-space': 'pre-wrap'}))" + ] + }, + { + "cell_type": "code", + "execution_count": 411, + "id": "9bdf454c-e5fa-45bf-a2c9-b5c938155c8b", + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "$R_0 \\sim N(1.0,1.0^2)$, $R_1 \\sim N(2.0, 1.0^2)$ \n", + " $\\pi_b=[0.8,0.2]$, $\\pi_e=[0.1,0.9]$\n" + ] + }, + { + "data": { + "image/svg+xml": [ + "\n", + "\n", + "\n", + "\n", + " \n", + " \n", + " \n", + " \n", + " 2023-04-25T16:48:27.335978\n", + " image/svg+xml\n", + " \n", + " \n", + " Matplotlib v3.3.4, https://matplotlib.org/\n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " 
\n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " 
\n", + " \n", + " \n", + " \n", + " \n", + " \n", + "\n" + ], + "text/plain": [ + "
" + ] + }, + "metadata": { + "needs_background": "light" + }, + "output_type": "display_data" + }, + { + "data": { + "image/svg+xml": [ + "\n", + "\n", + "\n", + "\n", + " \n", + " \n", + " \n", + " \n", + " 2023-04-25T16:48:29.037954\n", + " image/svg+xml\n", + " \n", + " \n", + " Matplotlib v3.3.4, https://matplotlib.org/\n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " 
\n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + "\n" + ], + "text/plain": [ + "
" + ] + }, + "metadata": { + "needs_background": "light" + }, + "output_type": "display_data" + }, + { + "data": { + "image/svg+xml": [ + "\n", + "\n", + "\n", + "\n", + " \n", + " \n", + " \n", + " \n", + " 2023-04-25T16:48:30.279476\n", + " image/svg+xml\n", + " \n", + " \n", + " Matplotlib v3.3.4, https://matplotlib.org/\n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " 
\n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + "\n" + ], + "text/plain": [ + "
" + ] + }, + "metadata": { + "needs_background": "light" + }, + "output_type": "display_data" + } + ], + "source": [ + "variances_grid = [df__.iloc[4,3] for df__ in df_out_sweepW_7]\n", + "\n", + "fig, ax = plt.subplots(figsize=(3,3))\n", + "df_plot = pd.DataFrame(np.array(variances_grid).reshape((11,11)), index=w_list, columns=w_list)\n", + "df_plot.index.name = 'w0'\n", + "df_plot.index = np.round(df_plot.index, 5)\n", + "df_plot.columns.name = 'w1'\n", + "df_plot.columns = np.round(df_plot.columns, 5)\n", + "\n", + "# sns.heatmap(df_plot.T, cmap='crest', square=True, ax=ax,\n", + "# cbar_kws={'label': 'Var', \"shrink\": .82}, vmin=0, vmax=4)\n", + "sns.heatmap(np.log(df_plot.T), cmap='mako_r', square=True, ax=ax,\n", + " cbar_kws=dict(shrink=.82, aspect=40, pad=0.04), \n", + " # vmin=-0.3, vmax=1.3,\n", + " )\n", + "# ax.collections[0].colorbar.set_label('$\\log(\\mathrm{Var})$', labelpad=-9, fontsize=8)\n", + "ax.collections[0].colorbar.set_ticks([0,1])\n", + "plt.xticks([0.5,10.5], ['$0$','$1$'], rotation=0)\n", + "plt.yticks([0.5,10.5], ['$0$','$1$'])\n", + "ax.yaxis.set_ticks_position('none')\n", + "ax.xaxis.set_ticks_position('none')\n", + "plt.xlabel('$W(\\\\tilde{a}=0|s,a=0)$', labelpad=-9)\n", + "plt.ylabel('$W(\\\\tilde{a}=1|s,a=1)$', labelpad=-9)\n", + "\n", + "\n", + "for (i,w0), (j,w1) in itertools.product(enumerate(w_list), enumerate(w_list)):\n", + " if np.isnan(df_plot.loc[w0,w1]):\n", + " plt.plot(i+0.5,j+0.5, marker='x', c='gray')\n", + " if np.isclose(df_plot.loc[w0,w1], np.nanmin(df_plot.values)):\n", + " plt.annotate('$*$', (i+0.5,j+0.5), c='k', ha='center', va='center')\n", + " if w0 == 0.5 and w1 == 0.5:\n", + " plt.plot(i+0.5,j+0.5, marker='o', mfc='none', mec='r', ms=10)\n", + "\n", + "plt.plot(10+0.5,10+0.5, marker='.', c='yellow', ms=8)\n", + "# plt.title('$R_0 \\sim N({},{}^2)$, $R_1 \\sim N({}, {}^2)$ \\n $\\pi_b=[{},{}]$, $\\pi_e=[{},{}]$'.format(\n", + "# R[0,0], sigma[0,0], R[0,1], sigma[0,1], π_b[0,0], π_b[0,1], π_e[0,0], 
π_e[0,1], \n", + "# ))\n", + "plt.gca().invert_yaxis()\n", + "ax.spines['top'].set_visible(True)\n", + "ax.spines['bottom'].set_visible(True)\n", + "ax.spines['right'].set_visible(True)\n", + "ax.spines['left'].set_visible(True)\n", + "\n", + "print('$R_0 \\sim N({},{}^2)$, $R_1 \\sim N({}, {}^2)$ \\n $\\pi_b=[{},{}]$, $\\pi_e=[{},{}]$'.format(\n", + " R[0,0], sigma[0,0], R[0,1], sigma[0,1], π_b[0,0], π_b[0,1], π_e[0,0], π_e[0,1], \n", + "))\n", + "plt.savefig('fig/bandit_logvar_b91_e9505.pdf', bbox_inches='tight')\n", + "plt.show()\n", + "# display(df_plot)\n", + "\n", + "\n", + "biases_grid = [df__.iloc[4,2] for df__ in df_out_sweepW_7]\n", + "\n", + "fig, ax = plt.subplots(figsize=(3,3))\n", + "df_plot = pd.DataFrame(np.array(biases_grid).reshape((11,11)), index=w_list, columns=w_list)\n", + "df_plot.index.name = 'w0'\n", + "df_plot.index = np.round(df_plot.index, 5)\n", + "df_plot.columns.name = 'w1'\n", + "df_plot.columns = np.round(df_plot.columns, 5)\n", + "\n", + "sns.heatmap(df_plot.T, cmap='crest', square=True, cbar_kws={'label': 'Bias', \"shrink\": .82}, vmin=0, vmax=2, ax=ax)\n", + "plt.xticks([0.5,10.5], [0,1], rotation=0)\n", + "plt.yticks([0.5,10.5], [0,1])\n", + "ax.yaxis.set_ticks_position('none')\n", + "ax.xaxis.set_ticks_position('none')\n", + "plt.xlabel('$W(\\\\tilde{a}=0|s,a=0)$', labelpad=-9)\n", + "plt.ylabel('$W(\\\\tilde{a}=1|s,a=1)$', labelpad=-9)\n", + "\n", + "for (i,w0), (j,w1) in itertools.product(enumerate(w_list), enumerate(w_list)):\n", + " if np.isnan(df_plot.loc[w0,w1]):\n", + " plt.plot(i+0.5,j+0.5, marker='x', c='r')\n", + "\n", + "# plt.title('$R_0 \\sim N({},{}^2)$, $R_1 \\sim N({}, {}^2)$ \\n $\\pi_b=[{},{}]$, $\\pi_e=[{},{}]$'.format(\n", + "# R[0,0], sigma[0,0], R[0,1], sigma[0,1], π_b[0,0], π_b[0,1], π_e[0,0], π_e[0,1], \n", + "# ))\n", + "plt.gca().invert_yaxis()\n", + "ax.spines['top'].set_visible(True)\n", + "ax.spines['bottom'].set_visible(True)\n", + "ax.spines['right'].set_visible(True)\n", + 
"ax.spines['left'].set_visible(True)\n", + "\n", + "plt.show()\n", + "\n", + "\n", + "mses_grid = [df__.iloc[4,4] for df__ in df_out_sweepW_7]\n", + "\n", + "fig, ax = plt.subplots(figsize=(3,3))\n", + "df_plot = pd.DataFrame(np.array(mses_grid).reshape((11,11)), index=w_list, columns=w_list)\n", + "df_plot.index.name = 'w0'\n", + "df_plot.index = np.round(df_plot.index, 5)\n", + "df_plot.columns.name = 'w1'\n", + "df_plot.columns = np.round(df_plot.columns, 5)\n", + "\n", + "sns.heatmap(df_plot.T, cmap='crest', square=True, cbar_kws={'label': 'MSE', \"shrink\": .82}, vmin=0, vmax=4, ax=ax)\n", + "plt.xticks([0.5,10.5], [0,1], rotation=0)\n", + "plt.yticks([0.5,10.5], [0,1])\n", + "ax.yaxis.set_ticks_position('none')\n", + "ax.xaxis.set_ticks_position('none')\n", + "plt.xlabel('$W(\\\\tilde{a}=0|s,a=0)$', labelpad=-9)\n", + "plt.ylabel('$W(\\\\tilde{a}=1|s,a=1)$', labelpad=-9)\n", + "\n", + "for (i,w0), (j,w1) in itertools.product(enumerate(w_list), enumerate(w_list)):\n", + " if np.isnan(df_plot.loc[w0,w1]):\n", + " plt.plot(i+0.5,j+0.5, marker='x', c='r')\n", + "\n", + "# plt.title('$R_0 \\sim N({},{}^2)$, $R_1 \\sim N({}, {}^2)$ \\n $\\pi_b=[{},{}]$, $\\pi_e=[{},{}]$'.format(\n", + "# R[0,0], sigma[0,0], R[0,1], sigma[0,1], π_b[0,0], π_b[0,1], π_e[0,0], π_e[0,1], \n", + "# ))\n", + "plt.gca().invert_yaxis()\n", + "ax.spines['top'].set_visible(True)\n", + "ax.spines['bottom'].set_visible(True)\n", + "ax.spines['right'].set_visible(True)\n", + "ax.spines['left'].set_visible(True)\n", + "\n", + "plt.show()" + ] + }, + { + "cell_type": "markdown", + "id": "4303c58f-1007-4899-9f37-c65e6d55e178", + "metadata": {}, + "source": [ + "# Ideal counterfactual annotations, non-constant weights mean split" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "77e1563d-c442-47f9-acae-6f6ebe12fc52", + "metadata": {}, + "outputs": [], + "source": [ + "def single_run_Wrange(W_ranges, runs=1000, annot_std_scale=1.0):\n", + " np.random.seed(42)\n", + "\n", + " # 
True value of π_e\n", + " Js = []\n", + " for seed in range(runs):\n", + " rng = np.random.default_rng(seed=10+seed)\n", + " x = rng.choice(1, size=N, p=d0)\n", + " a = np.array([rng.choice(2, p=π_e[xi]) for xi in x])\n", + " r = np.array([rng.normal(R[xi,ai], sigma[xi,ai]) for xi,ai in zip(x,a)])\n", + " J = np.sum(r) / N\n", + " Js.append(J)\n", + "\n", + " # Standard IS\n", + " Gs = []\n", + " OISs = []\n", + " WISs = []\n", + " for seed in range(runs):\n", + " rng = np.random.default_rng(seed=10+seed)\n", + " x = rng.choice(1, size=N, p=d0)\n", + " a = np.array([rng.choice(2, p=π_b[xi]) for xi in x])\n", + " r = np.array([rng.normal(R[xi,ai], sigma[xi,ai]) for xi,ai in zip(x,a)])\n", + " G = np.sum(r) / N\n", + " Gs.append(G)\n", + "\n", + " if use_πD:\n", + " π_b_ = np.array([\n", + " [(np.sum((x==0)&(a==0)))/np.sum(x==0), \n", + " (np.sum((x==0)&(a==1)))/np.sum(x==0)],\n", + " ])\n", + " else:\n", + " π_b_ = π_b\n", + "\n", + " rho = π_e[x,a] / π_b_[x,a]\n", + " OISs.append(np.sum(rho * r) / N)\n", + " WISs.append(np.sum(rho * r) / np.sum(rho))\n", + "\n", + "\n", + " # Collect data using π_b - combining counterfactuals with factuals\n", + " FC_OISs_w = []\n", + " FC_WISs_w = []\n", + " for seed in range(runs):\n", + " rng = np.random.default_rng(seed=10+seed)\n", + " rng_c = np.random.default_rng(seed=100000+seed)\n", + " x = rng.choice(1, size=N, p=d0)\n", + " a = np.array([rng.choice(2, p=π_b[xi]) for xi in x])\n", + " r = np.array([rng.normal(R[xi,ai], sigma[xi,ai]) for xi,ai in zip(x,a)])\n", + " rho = π_e[x,a] / π_b[x,a]\n", + "\n", + " # counterfactual flag\n", + " c = np.array([rng_c.choice(2, p=[1-Pc[xi,ai], Pc[xi,ai]]) for xi,ai in zip(x,a)])\n", + "\n", + " # counterfactual reward\n", + " rc = np.array([rng_c.normal(R[xi,1-ai], annot_std_scale*sigma[xi,1-ai]) for xi,ai in zip(x,a)])\n", + " rc[c==0] = np.nan\n", + "\n", + " # trajectory-wise weight\n", + "# w = np.ones(N)\n", + "# w[c==1] = ww[x[c==1], a[c==1], a[c==1]]\n", + "# wc = 
np.zeros(N)\n", + "# wc[c==1] = ww[x[c==1], a[c==1], 1-a[c==1]]\n", + " \n", + " w = np.ones(N)\n", + " w[c==1] = np.random.uniform(ww_mean[x[c==1], a[c==1], a[c==1]] - W_ranges[x[c==1], a[c==1]]/2, \n", + " ww_mean[x[c==1], a[c==1], a[c==1]] + W_ranges[x[c==1], a[c==1]]/2)\n", + " # w[c==1] = w[c==1]/np.mean(w[c==1])*ww_mean[x[c==1], a[c==1], a[c==1]] ### renormalize such that mean factual weight is exactly ww_mean\n", + " wc = np.zeros(N)\n", + " wc[c==1] = 1 - w[c==1] #1 - ww[x[c==1], a[c==1]]\n", + " \n", + " if use_πD:\n", + " # augmented behavior policy\n", + " π_b_ = np.array([\n", + " [(np.sum(w*((x==0)&(a==0)))+np.sum(wc*((x==0)&(a==1)&(c==1))))/(np.sum(w*(x==0))+np.sum(wc*((x==0)&(c==1)))), \n", + " (np.sum(w*((x==0)&(a==1)))+np.sum(wc*((x==0)&(a==0)&(c==1))))/(np.sum(w*(x==0))+np.sum(wc*((x==0)&(c==1))))],\n", + " ])\n", + " else:\n", + " # augmented behavior policy\n", + " π_b_ = np.array([\n", + " [π_b[0,0]*ww_mean[0,0,0]+π_b[0,1]*ww_mean[0,1,0], π_b[0,0]*ww_mean[0,0,1]+π_b[0,1]*ww_mean[0,1,1]],\n", + " ])\n", + " π_b_ = π_b_ / π_b_.sum(axis=1, keepdims=True)\n", + "\n", + " FC_OISs_w.append(\n", + " (np.sum(w* π_e[x,a] / π_b_[x,a] * r) + np.nansum(wc* π_e[x,1-a] / π_b_[x,1-a] * rc)) / (N)\n", + " )\n", + " FC_WISs_w.append(\n", + " (np.sum(w* π_e[x,a] / π_b_[x,a] * r) + np.nansum(wc* π_e[x,1-a] / π_b_[x,1-a] * rc)) / (np.sum(w* π_e[x,a] / π_b_[x,a]) + np.sum((wc* π_e[x,1-a] / π_b_[x,1-a])[c==1])),\n", + " )\n", + "\n", + " df_bias_var = []\n", + " for name, values in [\n", + " ('$\\hat{v}(\\pi_e)$', Js),\n", + " ('$\\hat{v}(\\pi_b)$', Gs),\n", + " ('OIS', OISs),\n", + " ('WIS', WISs),\n", + " ('C-OIS', FC_OISs_w),\n", + " ('C-WIS', FC_WISs_w),\n", + " ]:\n", + " df_bias_var.append([name, \n", + " np.mean(values), \n", + " np.mean(values - d0@np.sum(π_e*R, axis=1)), \n", + " np.sqrt(np.var(values)), \n", + " np.sqrt(np.mean(np.square(values - d0@np.sum(π_e*R, axis=1))))])\n", + " return pd.DataFrame(df_bias_var, columns=['Approach', 'Mean', 'Bias', 
'Std', 'RMSE'])" + ] + }, + { + "cell_type": "code", + "execution_count": 192, + "id": "6594e5ba-4060-4225-b18e-f5725c9c213e", + "metadata": {}, + "outputs": [], + "source": [ + "## probability of getting a counterfactual annotation\n", + "Pc = np.array([\n", + " [1., 1.],\n", + "])" + ] + }, + { + "cell_type": "code", + "execution_count": 290, + "id": "6be2ca1d-b430-436c-b674-8c5446d347b5", + "metadata": {}, + "outputs": [ + { + "data": { + "text/plain": [ + "[0.0, 0.1, 0.2, 0.3, 0.4, 0.5, 0.6, 0.7, 0.8, 0.9, 1.0]" + ] + }, + "execution_count": 290, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "# Unf dist, mean 0.5, range sweep from 0 to 1\n", + "ww_mean = np.array([[\n", + " [0.5, 0.5],\n", + " [0.5, 0.5],\n", + "]])\n", + "ww_range_list = list(np.arange(0,1+1e-10,0.1).round(2))\n", + "ww_range_list" + ] + }, + { + "cell_type": "markdown", + "id": "4b1bcc2c-8e5c-4240-9d81-2fca4ad20c58", + "metadata": {}, + "source": [ + "## >>Policy combination 1" + ] + }, + { + "cell_type": "code", + "execution_count": 291, + "id": "5d12cca8-c4ea-4bb4-b4c5-8d170096fc13", + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "[[0.8 0.2]] [[0.1 0.9]]\n" + ] + } + ], + "source": [ + "π_b = np.array([[0.8, 0.2]])\n", + "π_e = np.array([[0.1, 0.9]])\n", + "print(π_b, π_e)" + ] + }, + { + "cell_type": "code", + "execution_count": 292, + "id": "79f3eea3-e1a0-4314-93bd-ca0c273a53c6", + "metadata": {}, + "outputs": [], + "source": [ + "df_out_sweepWrange = []\n", + "for w0_range, w1_range in itertools.product(ww_range_list, ww_range_list):\n", + " df_out = single_run_Wrange(np.array([[w0_range, w1_range]]))\n", + " df_out_sweepWrange.append(df_out)" + ] + }, + { + "cell_type": "code", + "execution_count": 352, + "id": "49456660-d39c-4d43-8d82-ce0763bb67eb", + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "$R_0 \\sim N(1.0,1^2)$, $R_1 \\sim N(2.0, 1^2)$ \n", + " 
$\\pi_b=[0.8,0.2]$, $\\pi_e=[0.1,0.9]$\n" + ] + }, + { + "data": { + "image/svg+xml": [ + "\n", + "\n", + "\n", + "\n", + " \n", + " \n", + " \n", + " \n", + " 2023-04-25T14:45:33.331570\n", + " image/svg+xml\n", + " \n", + " \n", + " Matplotlib v3.3.4, https://matplotlib.org/\n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " 
\n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + "\n" + ], + "text/plain": [ + "
" + ] + }, + "metadata": { + "needs_background": "light" + }, + "output_type": "display_data" + }, + { + "data": { + "image/svg+xml": [ + "\n", + "\n", + "\n", + "\n", + " \n", + " \n", + " \n", + " \n", + " 2023-04-25T14:45:33.626845\n", + " image/svg+xml\n", + " \n", + " \n", + " Matplotlib v3.3.4, https://matplotlib.org/\n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " 
\n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + "\n" + ], + "text/plain": [ + "
" + ] + }, + "metadata": { + "needs_background": "light" + }, + "output_type": "display_data" + }, + { + "data": { + "image/svg+xml": [ + "\n", + "\n", + "\n", + "\n", + " \n", + " \n", + " \n", + " \n", + " 2023-04-25T14:45:33.946866\n", + " image/svg+xml\n", + " \n", + " \n", + " Matplotlib v3.3.4, https://matplotlib.org/\n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " 
\n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + "\n" + ], + "text/plain": [ + "
" + ] + }, + "metadata": { + "needs_background": "light" + }, + "output_type": "display_data" + } + ], + "source": [ + "variances_grid = [df__.iloc[4,3] for df__ in df_out_sweepWrange]\n", + "\n", + "fig, ax = plt.subplots(figsize=(3,3))\n", + "df_plot = pd.DataFrame(np.array(variances_grid).reshape((11,11)), index=ww_range_list, columns=ww_range_list)\n", + "df_plot.index.name = 'w0'\n", + "df_plot.index = np.round(df_plot.index, 5)\n", + "df_plot.columns.name = 'w1'\n", + "df_plot.columns = np.round(df_plot.columns, 5)\n", + "\n", + "# sns.heatmap(df_plot.T, cmap='crest', square=True, ax=ax,\n", + "# cbar_kws={'label': 'Var', \"shrink\": .82}, vmin=0, vmax=4)\n", + "sns.heatmap(np.log(df_plot.T), cmap='mako_r', square=True, ax=ax,\n", + " cbar_kws=dict(shrink=.82, aspect=40, pad=0.04), \n", + " # vmin=-0.3, vmax=1.3,\n", + " )\n", + "# ax.collections[0].colorbar.set_label('$\\log(\\mathrm{Var})$', labelpad=-9, fontsize=8)\n", + "ax.collections[0].colorbar.set_ticks([0, 0.25])\n", + "ax.collections[0].colorbar.set_ticklabels(['$0$', r'$\\frac{1}{4}$'])\n", + "plt.xticks([0.5,10.5], ['$0$','$1$'], rotation=0)\n", + "plt.yticks([0.5,10.5], ['$0$','$1$'])\n", + "ax.yaxis.set_ticks_position('none')\n", + "ax.xaxis.set_ticks_position('none')\n", + "plt.xlabel('$\\\\mathrm{range}(W(\\\\cdot|s,a=0))$', labelpad=-9)\n", + "plt.ylabel('$\\\\mathrm{range}(W(\\\\cdot|s,a=1))$', labelpad=-9)\n", + "\n", + "for (i,w0), (j,w1) in itertools.product(enumerate(ww_range_list), enumerate(ww_range_list)):\n", + " if np.isnan(df_plot.loc[w0,w1]):\n", + " plt.plot(i+0.5,j+0.5, marker='x', c='gray')\n", + " if np.isclose(df_plot.loc[w0,w1], np.nanmin(df_plot.values)):\n", + " plt.annotate('$*$', (i+0.5,j+0.5), c='k', ha='center', va='center')\n", + " if w0 == 0. 
and w1 == 0.:\n", + " plt.plot(i+0.5,j+0.5, marker='o', mfc='none', mec='r', ms=10)\n", + "\n", + "# plt.title('$R_0 \\sim N({},{}^2)$, $R_1 \\sim N({}, {}^2)$ \\n $\\pi_b=[{},{}]$, $\\pi_e=[{},{}]$'.format(\n", + "# R[0,0], sigma[0,0], R[0,1], sigma[0,1], π_b[0,0], π_b[0,1], π_e[0,0], π_e[0,1], \n", + "# ))\n", + "plt.gca().invert_yaxis()\n", + "ax.spines['top'].set_visible(True)\n", + "ax.spines['bottom'].set_visible(True)\n", + "ax.spines['right'].set_visible(True)\n", + "ax.spines['left'].set_visible(True)\n", + "\n", + "print('$R_0 \\sim N({},{}^2)$, $R_1 \\sim N({}, {}^2)$ \\n $\\pi_b=[{},{}]$, $\\pi_e=[{},{}]$'.format(\n", + " R[0,0], sigma[0,0], R[0,1], sigma[0,1], π_b[0,0], π_b[0,1], π_e[0,0], π_e[0,1], \n", + "))\n", + "plt.savefig('fig/bandit_logvar_b82_e19_wrange.pdf', bbox_inches='tight')\n", + "plt.show()\n", + "# display(df_plot)\n", + "\n", + "\n", + "biases_grid = [df__.iloc[4,2] for df__ in df_out_sweepWrange]\n", + "\n", + "fig, ax = plt.subplots(figsize=(3,3))\n", + "df_plot = pd.DataFrame(np.array(biases_grid).reshape((11,11)), index=ww_range_list, columns=ww_range_list)\n", + "df_plot.index.name = 'w0'\n", + "df_plot.index = np.round(df_plot.index, 5)\n", + "df_plot.columns.name = 'w1'\n", + "df_plot.columns = np.round(df_plot.columns, 5)\n", + "\n", + "sns.heatmap(df_plot.T, cmap='crest', square=True, cbar_kws={'label': 'Bias', \"shrink\": .82}, vmin=0, vmax=2, ax=ax)\n", + "plt.xticks([0.5,10.5], [0,1], rotation=0)\n", + "plt.yticks([0.5,10.5], [0,1])\n", + "ax.yaxis.set_ticks_position('none')\n", + "ax.xaxis.set_ticks_position('none')\n", + "plt.xlabel('$W(\\\\tilde{a}=0|s,a=0)$', labelpad=-9)\n", + "plt.ylabel('$W(\\\\tilde{a}=1|s,a=1)$', labelpad=-9)\n", + "\n", + "for (i,w0), (j,w1) in itertools.product(enumerate(ww_range_list), enumerate(ww_range_list)):\n", + " if np.isnan(df_plot.loc[w0,w1]):\n", + " plt.plot(i+0.5,j+0.5, marker='x', c='r')\n", + "\n", + "# plt.title('$R_0 \\sim N({},{}^2)$, $R_1 \\sim N({}, {}^2)$ \\n 
$\\pi_b=[{},{}]$, $\\pi_e=[{},{}]$'.format(\n", + "# R[0,0], sigma[0,0], R[0,1], sigma[0,1], π_b[0,0], π_b[0,1], π_e[0,0], π_e[0,1], \n", + "# ))\n", + "plt.gca().invert_yaxis()\n", + "ax.spines['top'].set_visible(True)\n", + "ax.spines['bottom'].set_visible(True)\n", + "ax.spines['right'].set_visible(True)\n", + "ax.spines['left'].set_visible(True)\n", + "\n", + "plt.show()\n", + "\n", + "\n", + "mses_grid = [df__.iloc[4,4] for df__ in df_out_sweepWrange]\n", + "\n", + "fig, ax = plt.subplots(figsize=(3,3))\n", + "df_plot = pd.DataFrame(np.array(mses_grid).reshape((11,11)), index=ww_range_list, columns=ww_range_list)\n", + "df_plot.index.name = 'w0'\n", + "df_plot.index = np.round(df_plot.index, 5)\n", + "df_plot.columns.name = 'w1'\n", + "df_plot.columns = np.round(df_plot.columns, 5)\n", + "\n", + "sns.heatmap(df_plot.T, cmap='crest', square=True, cbar_kws={'label': 'MSE', \"shrink\": .82}, \n", + " # vmin=0, vmax=4, \n", + " ax=ax)\n", + "plt.xticks([0.5,10.5], [0,1], rotation=0)\n", + "plt.yticks([0.5,10.5], [0,1])\n", + "ax.yaxis.set_ticks_position('none')\n", + "ax.xaxis.set_ticks_position('none')\n", + "plt.xlabel('$W(\\\\tilde{a}=0|s,a=0)$', labelpad=-9)\n", + "plt.ylabel('$W(\\\\tilde{a}=1|s,a=1)$', labelpad=-9)\n", + "\n", + "for (i,w0), (j,w1) in itertools.product(enumerate(ww_range_list), enumerate(ww_range_list)):\n", + " if np.isnan(df_plot.loc[w0,w1]):\n", + " plt.plot(i+0.5,j+0.5, marker='x', c='r')\n", + "\n", + "# plt.title('$R_0 \\sim N({},{}^2)$, $R_1 \\sim N({}, {}^2)$ \\n $\\pi_b=[{},{}]$, $\\pi_e=[{},{}]$'.format(\n", + "# R[0,0], sigma[0,0], R[0,1], sigma[0,1], π_b[0,0], π_b[0,1], π_e[0,0], π_e[0,1], \n", + "# ))\n", + "plt.gca().invert_yaxis()\n", + "ax.spines['top'].set_visible(True)\n", + "ax.spines['bottom'].set_visible(True)\n", + "ax.spines['right'].set_visible(True)\n", + "ax.spines['left'].set_visible(True)\n", + "\n", + "plt.show()" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": 
"679d1d8a-5626-4ad8-9956-f9ada38b7009", + "metadata": {}, + "outputs": [], + "source": [] + } + ], + "metadata": { + "kernelspec": { + "display_name": "RL_venv", + "language": "python", + "name": "rl_venv" + }, + "language_info": { + "codemirror_mode": { + "name": "ipython", + "version": 3 + }, + "file_extension": ".py", + "mimetype": "text/x-python", + "name": "python", + "nbconvert_exporter": "python", + "pygments_lexer": "ipython3", + "version": "3.8.8" + } + }, + "nbformat": 4, + "nbformat_minor": 5 +}