diff --git a/nbs/explore_huggingface_datasets.ipynb b/nbs/explore_huggingface_datasets.ipynb deleted file mode 100644 index cffef84af..000000000 --- a/nbs/explore_huggingface_datasets.ipynb +++ /dev/null @@ -1,1764 +0,0 @@ -{ - "cells": [ - { - "attachments": {}, - "cell_type": "markdown", - "metadata": {}, - "source": [ - "# Exploring the Extensibility of the 🤗 Datasets Library for Medical Images" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "id": null, - "metadata": {}, - "outputs": [], - "source": [ - "import glob\n", - "import os\n", - "from functools import partial\n", - "from typing import Dict, List\n", - "\n", - "import dask\n", - "import dask.dataframe as dd\n", - "import matplotlib.pyplot as plt\n", - "import numpy as np\n", - "import pandas as pd\n", - "import PIL\n", - "import plotly.graph_objects as go\n", - "import psutil\n", - "import seaborn as sns\n", - "import torch\n", - "import torchxrayvision as xrv\n", - "import yaml\n", - "from datasets import Dataset, load_dataset\n", - "from datasets.features import ClassLabel, Image\n", - "from datasets.splits import Split\n", - "from monai.transforms import (\n", - " AddChanneld,\n", - " CenterSpatialCropd,\n", - " Compose,\n", - " Lambdad,\n", - " ToDeviced,\n", - ")\n", - "from omegaconf import OmegaConf\n", - "from sklearn.compose import ColumnTransformer\n", - "from sklearn.impute import SimpleImputer\n", - "from sklearn.pipeline import Pipeline\n", - "from sklearn.preprocessing import OneHotEncoder, StandardScaler\n", - "from torchvision.transforms import PILToTensor\n", - "from use_cases.params.mimiciv.mortality_decompensation.constants_v1 import (\n", - " ENCOUNTERS_FILE,\n", - " QUERIED_DIR,\n", - " TAB_FEATURES,\n", - ")\n", - "\n", - "from cyclops.data.slicer import SliceSpec\n", - "from cyclops.evaluate.metrics import MetricCollection, create_metric\n", - "from cyclops.models.catalog import create_model, list_models\n", - "from cyclops.models.constants import CONFIG_ROOT\n", 
- "from cyclops.utils.file import join" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "id": null, - "metadata": {}, - "outputs": [], - "source": [ - "# CONSTANTS\n", - "NUM_PROC = 4\n", - "TORCH_BATCH_SIZE = 64" - ] - }, - { - "attachments": {}, - "cell_type": "markdown", - "metadata": {}, - "source": [ - "## Exploring existing functionalities that are relevant to CyclOps" - ] - }, - { - "attachments": {}, - "cell_type": "markdown", - "metadata": {}, - "source": [ - "### Tabular Data" - ] - }, - { - "attachments": {}, - "cell_type": "markdown", - "metadata": {}, - "source": [ - "#### Constructing a 🤗 Dataset from MIMICIV-v2.0 PostgreSQL Database" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "id": null, - "metadata": {}, - "outputs": [], - "source": [ - "db_cfg = OmegaConf.load(join(\"..\", \"cyclops\", \"query\", \"configs\", \"config.yaml\"))\n", - "\n", - "con_str = (\n", - " db_cfg.dbms\n", - " + \"://\"\n", - " + db_cfg.user\n", - " + \":\"\n", - " + db_cfg.password\n", - " + \"@\"\n", - " + db_cfg.host\n", - " + \"/\"\n", - " + db_cfg.database\n", - ")\n", - "\n", - "ds = Dataset.from_sql(\n", - " sql=\"SELECT * FROM mimiciv_hosp.patients LIMIT 1000\",\n", - " con=con_str,\n", - " keep_in_memory=True,\n", - ")\n", - "ds" - ] - }, - { - "attachments": {}, - "cell_type": "markdown", - "metadata": {}, - "source": [ - "#### Constructing a 🤗 Dataset from local parquet files" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "id": null, - "metadata": {}, - "outputs": [], - "source": [ - "parquet_files = list(glob.glob(join(QUERIED_DIR, \"*.parquet\")))\n", - "len(parquet_files)" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "id": null, - "metadata": {}, - "outputs": [], - "source": [ - "# take the first 300 files\n", - "parquet_files = parquet_files[:300]" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "id": null, - "metadata": {}, - "outputs": [], - "source": [ - "mimiciv_ds 
= load_dataset(\n", - " \"parquet\",\n", - " data_files=parquet_files,\n", - " split=Split.ALL,\n", - " num_proc=NUM_PROC,\n", - ")\n", - "\n", - "# clear all other cache files, except for the current cache file\n", - "mimiciv_ds.cleanup_cache_files()" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "id": null, - "metadata": {}, - "outputs": [], - "source": [ - "size_gb = mimiciv_ds.dataset_size / (1024**3)\n", - "print(f\"Dataset size (cache file) : {size_gb:.2f} GB\")\n", - "print(f\"RAM used: {psutil.Process().memory_info().rss / (1024 * 1024):.2f} MB\")" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "id": null, - "metadata": {}, - "outputs": [], - "source": [ - "mimiciv_ds.features" - ] - }, - { - "attachments": {}, - "cell_type": "markdown", - "metadata": {}, - "source": [ - "##### Benchmarking Filtering operations: 🤗 Dataset vs. Dask" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "id": null, - "metadata": {}, - "outputs": [], - "source": [ - "dask.config.set(scheduler=\"processes\", num_workers=NUM_PROC)\n", - "\n", - "ddf = dd.read_parquet(parquet_files)\n", - "len(ddf)" - ] - }, - { - "attachments": {}, - "cell_type": "markdown", - "metadata": {}, - "source": [ - "1. **Filtering on 1 column**\n", - "\n", - "Get all rows where the value in column `event_category` is in a given list of values."
- ] - }, - { - "cell_type": "code", - "execution_count": null, - "id": null, - "metadata": {}, - "outputs": [], - "source": [ - "event_filter = [\n", - " \"Cadiovascular\",\n", - " \"Dialysis\",\n", - " \"Hemodynamics\",\n", - " \"Neurological\",\n", - " \"Toxicology\",\n", - " \"General\",\n", - "]" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "id": null, - "metadata": {}, - "outputs": [], - "source": [ - "%%timeit\n", - "events_ddf = ddf[ddf[\"event_category\"].isin(event_filter)].compute()" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "id": null, - "metadata": {}, - "outputs": [], - "source": [ - "%%timeit\n", - "events_ds = mimiciv_ds.filter(\n", - " lambda examples: [\n", - " example in event_filter for example in examples[\"event_category\"]\n", - " ],\n", - " batched=True,\n", - " num_proc=NUM_PROC,\n", - " load_from_cache_file=False, # timeit will run multiple times\n", - ")" - ] - }, - { - "attachments": {}, - "cell_type": "markdown", - "metadata": {}, - "source": [ - "2. **Filtering on multiple columns**\n", - "\n", - "Get all items where the values in two columns are in a list of values for each column." 
- ] - }, - { - "cell_type": "code", - "execution_count": null, - "id": null, - "metadata": {}, - "outputs": [], - "source": [ - "discharge_location_filter = [\"HOME\", \"HOME HEALTH CARE\"]\n", - "admission_location_filter = [\n", - " \"TRANSFER FROM HOSPITAL\",\n", - " \"PHYSICIAN REFERRAL\",\n", - " \"CLINIC REFERRAL\",\n", - "]" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "id": null, - "metadata": {}, - "outputs": [], - "source": [ - "%%timeit\n", - "\n", - "location_ddf = ddf[\n", - " (ddf[\"discharge_location\"].isin(discharge_location_filter))\n", - " & (ddf[\"admission_location\"].isin(admission_location_filter))\n", - "].compute()" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "id": null, - "metadata": {}, - "outputs": [], - "source": [ - "%%timeit\n", - "\n", - "location_ds = mimiciv_ds.filter(\n", - " lambda examples: [\n", - " example[0] in discharge_location_filter\n", - " and example[1] in admission_location_filter\n", - " for example in zip(\n", - " examples[\"discharge_location\"],\n", - " examples[\"admission_location\"],\n", - " )\n", - " ],\n", - " batched=True,\n", - " num_proc=NUM_PROC,\n", - " load_from_cache_file=False, # timeit will run multiple times\n", - ")" - ] - }, - { - "attachments": {}, - "cell_type": "markdown", - "metadata": {}, - "source": [ - "3. **Filtering on a datetime condition**\n", - "\n", - "Get all rows where `date of death` occurred after January 1, 2020." 
- ] - }, - { - "cell_type": "code", - "execution_count": null, - "id": null, - "metadata": {}, - "outputs": [], - "source": [ - "%%timeit\n", - "dod_ddf = ddf[ddf[\"dod\"] > datetime.datetime(2020, 1, 1)].compute()" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "id": null, - "metadata": {}, - "outputs": [], - "source": [ - "%%timeit\n", - "\n", - "dod_ds = mimiciv_ds.filter(\n", - " lambda examples: [\n", - " example is not None and example > datetime.datetime(2020, 1, 1)\n", - " for example in examples[\"dod\"]\n", - " ],\n", - " batched=True,\n", - " num_proc=NUM_PROC,\n", - " load_from_cache_file=False, # timeit will run multiple times\n", - ")" - ] - }, - { - "attachments": {}, - "cell_type": "markdown", - "metadata": {}, - "source": [ - "4. **Filter on a condition on a column**" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "id": null, - "metadata": {}, - "outputs": [], - "source": [ - "%%timeit\n", - "millennials_ddf = ddf[(ddf.age <= 40) & (ddf.age >= 25)].compute()" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "id": null, - "metadata": {}, - "outputs": [], - "source": [ - "%%timeit\n", - "millennials_ds = mimiciv_ds.filter(\n", - " lambda examples: [25 <= example <= 40 for example in examples[\"age\"]],\n", - " batched=True,\n", - " num_proc=NUM_PROC,\n", - " load_from_cache_file=False, # timeit will run multiple times\n", - ")" - ] - }, - { - "attachments": {}, - "cell_type": "markdown", - "metadata": {}, - "source": [ - "### Image Data - Constructing a 🤗 Dataset from image folder" - ] - }, - { - "attachments": {}, - "cell_type": "markdown", - "metadata": {}, - "source": [ - "From the 🤗 Datasets documentation, there are 3 ways to load local image data into a 🤗 Dataset:\n", - "1. 
**Load images from a folder with the following structure:**\n", - " ```bash\n", - " root_folder/train/class1/img1.png\n", - " root_folder/train/class1/img2.png\n", - " root_folder/train/class2/img1.png\n", - " root_folder/train/class2/img2.png\n", - " root_folder/test/class1/img1.png\n", - " root_folder/test/class1/img2.png\n", - " root_folder/test/class2/img1.png\n", - " root_folder/test/class2/img2.png\n", - " ...\n", - " ```\n", - " The folder names are the class names and the dataset splits (train/test) will automatically be recognized.\n", - " The dataset can be loaded using the following code:\n", - " ```python\n", - " from datasets import load_dataset\n", - " dataset = load_dataset(\"imagefolder\", data_dir=\"root_folder\")\n", - " ```\n", - " (This method also supports loading remote image folders from URLs.)\n", - " \n", - " The downside of this approach is that it uses PIL to load the images, which does not support many medical image formats like DICOM and NIfTI.\n", - "\n", - "2. **Load images using a list of image paths**\n", - " ```python\n", - " from datasets import Dataset\n", - " from datasets.features import Image\n", - " dataset = Dataset.from_dict({\"image\": [\"path/to/img1.png\", \"path/to/img2.png\", ...]}).cast_column(\"image\", Image())\n", - " ```\n", - " This approach is more flexible than the previous one, but it still has the same limitation of not supporting many medical image formats.\n", - "\n", - "3. **Create a dataset loading script**\n", - "\n", - " This is the most flexible way to load and share different types of datasets that are not natively supported by 🤗 Datasets library.\n", - " In fact, the `imagefolder` dataset is an example of a dataset loading script. In essence, we can extend that script to support more image formats like DICOM and NIfTI. That solves half the problem. The other half is that we need to create a new feature to extend the `Image` class to support decoding medical image formats." 
- ] - }, - { - "attachments": {}, - "cell_type": "markdown", - "metadata": {}, - "source": [ - "#### Case Study: MIMIC-CXR-JPG v2.0.0" - ] - }, - { - "attachments": {}, - "cell_type": "markdown", - "metadata": {}, - "source": [ - "For this case study, we will combine CSV metadata and the `Image` feature to create a 🤗 Dataset from the MIMIC-CXR-JPG v2.0.0 dataset. The dataset is available on [PhysioNet](https://physionet.org/content/mimic-cxr-jpg/2.0.0/).\n", - "\n", - "The dataset comes with 4 compressed CSV metadata files. The metadata files are `mimic-cxr-2.0.0-split.csv.gz`, `mimic-cxr-2.0.0-chexpert.csv.gz`, `mimic-cxr-2.0.0-negbio.csv.gz`, and `mimic-cxr-2.0.0-metadata.csv.gz`. The `mimic-cxr-2.0.0-split.csv.gz` file contains the train/val/test split for each image. The `mimic-cxr-2.0.0-chexpert.csv.gz` file contains the CheXpert labels for each image. The `mimic-cxr-2.0.0-negbio.csv.gz` file contains the NegBio labels for each image. The `mimic-cxr-2.0.0-metadata.csv.gz` file contains other metadata for each image. All the metadata files can be joined on the `subject_id` and `study_id` columns." 
- ] - }, - { - "cell_type": "code", - "execution_count": null, - "id": null, - "metadata": {}, - "outputs": [], - "source": [ - "mimic_cxr_jpg_dir = \"/mnt/data/clinical_datasets/mimic-cxr-jpg-2.0.0\"" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "id": null, - "metadata": {}, - "outputs": [], - "source": [ - "# read metadata files using pandas\n", - "metadata_df = pd.read_csv(\n", - " os.path.join(mimic_cxr_jpg_dir, \"mimic-cxr-2.0.0-metadata.csv.gz\"),\n", - ")\n", - "negbio_df = pd.read_csv(\n", - " os.path.join(mimic_cxr_jpg_dir, \"mimic-cxr-2.0.0-negbio.csv.gz\"),\n", - ")\n", - "split_df = pd.read_csv(os.path.join(mimic_cxr_jpg_dir, \"mimic-cxr-2.0.0-split.csv.gz\"))" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "id": null, - "metadata": {}, - "outputs": [], - "source": [ - "# join the 3 metadata dataframes on subject_id and study_id\n", - "metadata_df = metadata_df.merge(\n", - " split_df,\n", - " on=[\"subject_id\", \"study_id\", \"dicom_id\"],\n", - ").merge(negbio_df, on=[\"subject_id\", \"study_id\"])" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "id": null, - "metadata": {}, - "outputs": [], - "source": [ - "# select rows with images in folder 'p10' i.e. 
subject_id starts with 10\n", - "metadata_df = metadata_df[metadata_df[\"subject_id\"].astype(str).str.startswith(\"10\")]" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "id": null, - "metadata": {}, - "outputs": [], - "source": [ - "# create HuggingFace Dataset from pandas DataFrame\n", - "mimic_cxr_ds = Dataset.from_pandas(\n", - " metadata_df[metadata_df.split == \"train\"],\n", - " split=\"train\",\n", - " preserve_index=False,\n", - ")\n", - "mimic_cxr_ds" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "id": null, - "metadata": {}, - "outputs": [], - "source": [ - "# create a new column with the full path to the image:\n", - "# mimic_cxr_jpg_dir + \"p10\" + \"p\" + subject_id + study_id + dicom_id + \".jpg\"\n", - "\n", - "\n", - "def get_filename(examples):\n", - " subject_ids = examples[\"subject_id\"]\n", - " study_ids = examples[\"study_id\"]\n", - " dicom_ids = examples[\"dicom_id\"]\n", - " examples[\"image\"] = [\n", - " os.path.join(\n", - " mimic_cxr_jpg_dir,\n", - " \"files\",\n", - " \"p10\",\n", - " \"p\" + str(subject_id),\n", - " \"s\" + str(study_id),\n", - " dicom_id + \".jpg\",\n", - " )\n", - " for subject_id, study_id, dicom_id in zip(subject_ids, study_ids, dicom_ids)\n", - " ]\n", - " return examples\n", - "\n", - "\n", - "mimic_cxr_ds = mimic_cxr_ds.map(\n", - " get_filename,\n", - " batched=True,\n", - " num_proc=NUM_PROC,\n", - " remove_columns=[\"dicom_id\", \"split\", \"Rows\", \"Columns\"],\n", - ")\n", - "mimic_cxr_ds" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "id": null, - "metadata": {}, - "outputs": [], - "source": [ - "mimic_cxr_ds = mimic_cxr_ds.cast_column(\"image\", Image())\n", - "mimic_cxr_ds.features" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "id": null, - "metadata": {}, - "outputs": [], - "source": [ - "from cyclops.data.utils import set_decode # noqa: E402\n", - "\n", - "\n", - "set_decode(mimic_cxr_ds, decode=False)\n", - 
"mimic_cxr_ds[0][\"image\"]" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "id": null, - "metadata": {}, - "outputs": [], - "source": [ - "set_decode(dataset=mimic_cxr_ds, decode=True)\n", - "mimic_cxr_ds[0]" - ] - }, - { - "attachments": {}, - "cell_type": "markdown", - "metadata": {}, - "source": [ - "## Extending 🤗 Dataset to Load DICOM (and NIfTI) images" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "id": null, - "metadata": {}, - "outputs": [], - "source": [ - "%matplotlib widget\n", - "\n", - "\n", - "# code for plotting 3D images\n", - "# Taken from: https://www.datacamp.com/tutorial/matplotlib-3d-volumetric-data\n", - "def multi_slice_viewer(volume):\n", - " fig, ax = plt.subplots()\n", - " ax.volume = volume\n", - " ax.index = volume.shape[0] // 2\n", - " ax.imshow(volume[ax.index], cmap=\"gray\")\n", - " fig.canvas.mpl_connect(\"key_press_event\", process_key)\n", - "\n", - "\n", - "def process_key(event):\n", - " fig = event.canvas.figure\n", - " ax = fig.axes[0]\n", - " if event.key == \"a\":\n", - " previous_slice(ax)\n", - " elif event.key == \"d\":\n", - " next_slice(ax)\n", - " fig.canvas.draw()\n", - "\n", - "\n", - "def previous_slice(ax):\n", - " \"\"\"Go to the previous slice.\"\"\"\n", - " volume = ax.volume\n", - " ax.index = (ax.index - 1) % volume.shape[0] # wrap around using %\n", - " ax.images[0].set_array(volume[ax.index])\n", - "\n", - "\n", - "def next_slice(ax):\n", - " \"\"\"Go to the next slice.\"\"\"\n", - " volume = ax.volume\n", - " ax.index = (ax.index + 1) % volume.shape[0]\n", - " ax.images[0].set_array(volume[ax.index])" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "id": null, - "metadata": {}, - "outputs": [], - "source": [ - "ROOT_DIR = \"/mnt/data/clinical_datasets/coherent-11-07-2022/dicom/\"\n", - "\n", - "dcm_files = glob.glob(ROOT_DIR + \"/**/*.dcm\", recursive=True)\n", - "len(dcm_files)" - ] - }, - { - "attachments": {}, - "cell_type": "markdown", - 
"metadata": {}, - "source": [ - "1. Create a new feature class that extends the `Image` class to support decoding medical image formats. Let's call it `MedicalImage`. This will use MONAI to decode the medical image formats." - ] - }, - { - "cell_type": "code", - "execution_count": null, - "id": null, - "metadata": {}, - "outputs": [], - "source": [ - "from cyclops.data import MedicalImage # noqa: E402\n", - "\n", - "\n", - "# or" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "id": null, - "metadata": {}, - "outputs": [], - "source": [ - "dicom_ds = Dataset.from_dict({\"image\": dcm_files}).cast_column(\"image\", MedicalImage())\n", - "print(\"Number of rows: \", dicom_ds.num_rows)\n", - "print(\"Features: \", dicom_ds.features)\n", - "print(\"Image column contents: \", list(dicom_ds[0][\"image\"].keys()))" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "id": null, - "metadata": {}, - "outputs": [], - "source": [ - "img = dicom_ds[0][\"image\"][\"array\"].shape" - ] - }, - { - "attachments": {}, - "cell_type": "markdown", - "metadata": {}, - "source": [ - "2. Create a new dataset loading script that extends the `imagefolder` dataset \n", - "loading script to support the `MedicalImage` feature class. We can call it \n", - "`medical_imagefolder`. \n", - "\n", - "For cyclops, the dataset loading script can be found in `cyclops/datasets/packaged_loading_scripts`.\n", - "Our new dataset loading script can be used with `load_dataset` by simply passing\n", - "the string `\"medicalimagefolder\"` to the `path` argument. This works because\n", - "we have added the path to the script to huggingface's _PACKAGED_DATASETS_MODULES\n", - "registry in `cyclops/datasets/__init__.py`. This means that `cyclops.data`\n", - "must be imported for the script to be registered."
- ] - }, - { - "cell_type": "code", - "execution_count": null, - "id": null, - "metadata": {}, - "outputs": [], - "source": [ - "med_ds = load_dataset(\"medicalimagefolder\", data_files=dcm_files, split=Split.ALL)\n", - "print(\"Number of rows: \", med_ds.num_rows)\n", - "print(\"Features: \", med_ds.features)\n", - "print(\"Image column contents: \", list(med_ds[0][\"image\"].keys()))" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "id": null, - "metadata": {}, - "outputs": [], - "source": [ - "med_img = med_ds[150][\"image\"][\"array\"]\n", - "multi_slice_viewer(med_img.T)" - ] - }, - { - "attachments": {}, - "cell_type": "markdown", - "metadata": {}, - "source": [ - "### Some Challenges\n", - "\n", - "1. Handling metadata. What to do with it?\n", - "2. Encoding and decoding image bytes in the formats that are supported by the `MedicalImage` feature class." - ] - }, - { - "attachments": {}, - "cell_type": "markdown", - "metadata": {}, - "source": [ - "## Exploring Training and Evaluation of Scikit-Learn and PyTorch Models" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "id": null, - "metadata": {}, - "outputs": [], - "source": [ - "from cyclops.evaluate import evaluator # noqa: E402\n", - "from cyclops.evaluate.fairness import (\n", - " FairnessConfig, # noqa: E402\n", - " evaluate_fairness, # noqa: E402\n", - ")" - ] - }, - { - "attachments": {}, - "cell_type": "markdown", - "metadata": {}, - "source": [ - "### Scikit-Learn" - ] - }, - { - "attachments": {}, - "cell_type": "markdown", - "metadata": {}, - "source": [ - "#### Data Loading" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "id": null, - "metadata": {}, - "outputs": [], - "source": [ - "encounters_ds = load_dataset(\n", - " \"parquet\",\n", - " data_files=ENCOUNTERS_FILE,\n", - " split=Split.ALL,\n", - " keep_in_memory=True,\n", - ")\n", - "encounters_ds.cleanup_cache_files()\n", - "encounters_ds" - ] - }, - { - "cell_type": "code", - 
"execution_count": null, - "id": null, - "metadata": {}, - "outputs": [], - "source": [ - "# split into train and test - 0.6, 0.4\n", - "# NOTE: train_test_split does not work with IterableDataset objects\n", - "encounters_ds = encounters_ds.cast_column(TAB_FEATURES[-1], ClassLabel(num_classes=2))\n", - "encounters_ds = encounters_ds.train_test_split(\n", - " test_size=0.4,\n", - " seed=42,\n", - " stratify_by_column=TAB_FEATURES[-1],\n", - ")\n", - "encounters_ds" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "id": null, - "metadata": {}, - "outputs": [], - "source": [ - "TAB_FEATURES" - ] - }, - { - "attachments": {}, - "cell_type": "markdown", - "metadata": {}, - "source": [ - "#### Pre-processing" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "id": null, - "metadata": {}, - "outputs": [], - "source": [ - "# pre-processing pipeline\n", - "numeric_features = [0] # ['age']\n", - "numeric_transformer = Pipeline(\n", - " steps=[(\"imputer\", SimpleImputer(strategy=\"median\")), (\"scaler\", StandardScaler())],\n", - ")\n", - "\n", - "categorical_features = [1, 2, 3] # ['sex', 'admission_type', 'admission_location']\n", - "categorical_transformer = OneHotEncoder(handle_unknown=\"ignore\")\n", - "\n", - "preprocessor = ColumnTransformer(\n", - " transformers=[\n", - " (\"num\", numeric_transformer, numeric_features),\n", - " (\"cat\", categorical_transformer, categorical_features),\n", - " ],\n", - ")" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "id": null, - "metadata": {}, - "outputs": [], - "source": [ - "# get a count of the positive and negative samples\n", - "import pyarrow.compute as pc # noqa: E402\n", - "\n", - "\n", - "value_counts = pc.value_counts(encounters_ds[\"train\"]._data[TAB_FEATURES[-1]]).tolist()\n", - "pos_count = value_counts[1][\"counts\"]\n", - "neg_count = value_counts[0][\"counts\"]" - ] - }, - { - "attachments": {}, - "cell_type": "markdown", - "metadata": {}, - "source": [ - "#### 
Training" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "id": null, - "metadata": {}, - "outputs": [], - "source": [ - "model_dict = {}\n", - "\n", - "for model_name in list_models(\"sklearn\"):\n", - " if \"classifier\" not in model_name: # use only classifiers\n", - " continue\n", - "\n", - " # load the config file for the model\n", - " config_path = join(CONFIG_ROOT, model_name + \".yaml\")\n", - " with open(config_path, \"r\") as f:\n", - " cfg = yaml.safe_load(f)\n", - "\n", - " if model_name == \"xgb_classifier\":\n", - " # set the scale_pos_weight parameter to account for the class imbalance\n", - " cfg[\"scale_pos_weight\"] = neg_count / pos_count\n", - "\n", - " model_dict[model_name] = create_model(model_name, **cfg)" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "id": null, - "metadata": {}, - "outputs": [], - "source": [ - "for model_name, model in model_dict.items():\n", - " print(f\"Training {model_name}...\")\n", - " model_dict[model_name] = model.fit(\n", - " encounters_ds[\"train\"],\n", - " feature_columns=TAB_FEATURES[:-1],\n", - " target_columns=[TAB_FEATURES[-1]],\n", - " transforms=preprocessor,\n", - " )" - ] - }, - { - "attachments": {}, - "cell_type": "markdown", - "metadata": {}, - "source": [ - "#### Evaluation" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "id": null, - "metadata": {}, - "outputs": [], - "source": [ - "# specify some filters to apply to the dataset\n", - "slice_list = [\n", - " # remove null values in column\n", - " {\"dod\": {\"keep_nulls\": False}},\n", - " {\n", - " \"admission_type\": {\"keep_nulls\": True, \"negate\": True},\n", - " \"admission_location\": {\"keep_nulls\": False},\n", - " },\n", - " # filter by exact value\n", - " {\"sex\": {\"value\": \"M\"}},\n", - " # filter numeric values by range\n", - " {\n", - " \"age\": {\n", - " \"min_value\": 18,\n", - " \"max_value\": 65,\n", - " \"min_inclusive\": True,\n", - " \"max_inclusive\": False,\n", - " 
},\n", - " },\n", - " # filter by value in list\n", - " {\"admission_type\": {\"value\": [\"EW EMER.\", \"DIRECT EMER.\", \"URGENT\"]}},\n", - " # filter string values by substring\n", - " {\"admission_location\": {\"contains\": \"REFERRAL\"}},\n", - " # filter by date range (time string format: YYYY-MM-DD)\n", - " {\"dod\": {\"max_value\": \"2019-12-01\", \"keep_nulls\": True}},\n", - " # negate a filter\n", - " {\"dod\": {\"max_value\": \"2019-12-01\", \"negate\": True}},\n", - " # filter by month (1-12)\n", - " {\"admit_timestamp\": {\"month\": [6, 7, 8, 9], \"keep_nulls\": False}},\n", - " {\n", - " \"sex\": {\"value\": \"F\"},\n", - " \"race\": {\"contains\": [\"BLACK\", \"WHITE\"]},\n", - " \"age\": {\"min_value\": 25, \"max_value\": 40},\n", - " }, # compound slice\n", - "]\n", - "\n", - "# create the slice functions\n", - "slice_spec = SliceSpec()\n", - "for slice_ in slice_list:\n", - " slice_spec.add_slice_spec(slice_)\n", - "\n", - "# or" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "id": null, - "metadata": {}, - "outputs": [], - "source": [ - "# define the metrics\n", - "metric_names = [\"accuracy\", \"precision\", \"recall\", \"f1_score\", \"auroc\"]\n", - "metrics = [create_metric(metric_name, task=\"binary\") for metric_name in metric_names]\n", - "tab_metrics = MetricCollection(metrics)" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "id": null, - "metadata": {}, - "outputs": [], - "source": [ - "tab_eval_result = evaluator.evaluate(\n", - " encounters_ds,\n", - " tab_metrics,\n", - " split=\"test\",\n", - " models=model_dict,\n", - " transforms=preprocessor,\n", - " feature_columns=TAB_FEATURES[:-1],\n", - " target_columns=TAB_FEATURES[-1],\n", - " slice_spec=slice_spec,\n", - " batch_size=None, # load all data into memory\n", - ")" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "id": null, - "metadata": {}, - "outputs": [], - "source": [ - "# plot evaluation results\n", - "reformed_dict = 
{}\n", - "for outerKey, innerDict in tab_eval_result.items():\n", - " for innerKey, values in innerDict.items():\n", - " reformed_dict[(outerKey, innerKey)] = values\n", - "\n", - "tidy_df = pd.melt(\n", - " pd.DataFrame(reformed_dict).T.rename_axis([\"model\", \"slice\"]),\n", - " ignore_index=False,\n", - " var_name=\"metric\",\n", - ").reset_index()\n", - "\n", - "sns.catplot(\n", - " data=tidy_df,\n", - " x=\"slice\",\n", - " y=\"value\",\n", - " hue=\"model\",\n", - " row=\"slice\",\n", - " col=\"metric\",\n", - " kind=\"bar\",\n", - " sharey=True,\n", - " sharex=False,\n", - ")" - ] - }, - { - "attachments": {}, - "cell_type": "markdown", - "metadata": {}, - "source": [ - "##### Fairness" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "id": null, - "metadata": {}, - "outputs": [], - "source": [ - "specificity = create_metric(metric_name=\"specificity\", task=\"binary\")\n", - "sensitivity = create_metric(metric_name=\"sensitivity\", task=\"binary\")\n", - "\n", - "fpr = 1 - specificity\n", - "fnr = 1 - sensitivity\n", - "ber = (fpr + fnr) / 2 # balanced error rate\n", - "\n", - "fairness_metric_collection = MetricCollection(\n", - " {\n", - " \"Sensitivity\": sensitivity,\n", - " \"Specificity\": specificity,\n", - " \"FPR\": fpr,\n", - " \"FNR\": fnr,\n", - " \"BER\": ber,\n", - " },\n", - ")" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "id": null, - "metadata": {}, - "outputs": [], - "source": [ - "fairness_config = FairnessConfig(\n", - " metrics=fairness_metric_collection,\n", - " dataset=None, # dataset is passed from the evaluator\n", - " target_columns=None, # target columns are passed from the evaluator\n", - " groups=[\"sex\", \"age\"],\n", - " group_bins={\"age\": [26, 42, 58, 68]},\n", - " group_base_values={\"sex\": \"M\", \"age\": 40},\n", - " thresholds=[0.1, 0.5, 0.9],\n", - ")" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "id": null, - "metadata": {}, - "outputs": [], - "source": [ - 
"tab_model_analysis_results = evaluator.evaluate(\n", - " encounters_ds,\n", - " tab_metrics,\n", - " split=\"test\",\n", - " models=model_dict,\n", - " feature_columns=TAB_FEATURES[:-1],\n", - " target_columns=TAB_FEATURES[-1],\n", - " transforms=preprocessor,\n", - " slice_spec=slice_spec,\n", - " batch_size=-1, # use all examples at once\n", - " fairness_config=fairness_config,\n", - " override_fairness_metrics=False, # use separate metrics for evaluating fairness\n", - ")" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "id": null, - "metadata": {}, - "outputs": [], - "source": [ - "reformed_fairness_dict = {}\n", - "for outerKey, innerDict in tab_model_analysis_results[\"fairness\"].items():\n", - " for innerKey, values in innerDict.items():\n", - " reformed_fairness_dict[(outerKey, innerKey)] = values\n", - "\n", - "tidy_fairness_df = pd.melt(\n", - " pd.DataFrame(reformed_fairness_dict).T.rename_axis([\"model\", \"slice\"]),\n", - " ignore_index=False,\n", - " var_name=\"metric\",\n", - ").reset_index()\n", - "\n", - "sns.catplot(\n", - " data=tidy_fairness_df,\n", - " x=\"slice\",\n", - " y=\"value\",\n", - " hue=\"model\",\n", - " row=\"metric\",\n", - " col=\"slice\",\n", - " kind=\"bar\",\n", - " sharey=False,\n", - " sharex=False,\n", - ")" - ] - }, - { - "attachments": {}, - "cell_type": "markdown", - "metadata": {}, - "source": [ - "### PyTorch" - ] - }, - { - "attachments": {}, - "cell_type": "markdown", - "metadata": {}, - "source": [ - "#### Data Loading" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "id": null, - "metadata": {}, - "outputs": [], - "source": [ - "def nihcxr_preprocess(df: pd.DataFrame, nihcxr_dir: str) -> pd.DataFrame:\n", - " \"\"\"Preprocess NIHCXR dataframe.\n", - "\n", - " Add a column with the path to the image and create one-hot encoded pathologies\n", - " from the Finding Labels column.\n", - "\n", - " Args:\n", - " ----\n", - " df (pd.DataFrame): NIHCXR dataframe.\n", - "\n", - " Returns\n", -
" -------\n", - " pd.DataFrame: pre-processed NIHCXR dataframe.\n", - " \"\"\"\n", - " # Add path column\n", - " df[\"image\"] = df[\"Image Index\"].apply(\n", - " lambda x: os.path.join(nihcxr_dir, \"images\", x),\n", - " )\n", - "\n", - " # Create one-hot encoded pathologies\n", - " pathologies = df[\"Finding Labels\"].str.get_dummies(sep=\"|\")\n", - "\n", - " # Add one-hot encoded pathologies to dataframe\n", - " return pd.concat([df, pathologies], axis=1)\n", - "\n", - "\n", - "nihcxr_dir = \"/mnt/data/clinical_datasets/NIHCXR\"\n", - "\n", - "test_df = pd.read_csv(\n", - " join(nihcxr_dir, \"test_list.txt\"),\n", - " header=None,\n", - " names=[\"Image Index\"],\n", - ")\n", - "\n", - "# select only the images in the test list\n", - "df = pd.read_csv(join(nihcxr_dir, \"Data_Entry_2017.csv\"))\n", - "df.dropna(how=\"all\", axis=\"columns\", inplace=True) # drop empty columns\n", - "df = df[df[\"Image Index\"].isin(test_df[\"Image Index\"])]\n", - "\n", - "df = nihcxr_preprocess(df, nihcxr_dir)\n", - "\n", - "# create a Dataset object\n", - "nih_ds = Dataset.from_pandas(df, preserve_index=False)\n", - "nih_ds = nih_ds.cast_column(\"image\", Image())" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "id": null, - "metadata": {}, - "outputs": [], - "source": [ - "nih_ds.features" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "id": null, - "metadata": {}, - "outputs": [], - "source": [ - "device = \"cuda\" if torch.cuda.is_available() else \"cpu\"\n", - "print(device)" - ] - }, - { - "attachments": {}, - "cell_type": "markdown", - "metadata": {}, - "source": [ - "#### Pre-processing" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "id": null, - "metadata": {}, - "outputs": [], - "source": [ - "transforms = Compose(\n", - " [\n", - " # TorchVisiond(keys=(\"image\",), name=\"PILToTensor\"), doesn't work\n", - " AddChanneld(keys=(\"image\",)),\n", - " CenterSpatialCropd(keys=(\"image\",), roi_size=(1, 224, 
224)),\n", - " Lambdad(keys=(\"image\"), func=lambda x: ((2 * (x / 255.0)) - 1.0) * 1024),\n", - " ToDeviced(keys=(\"image\",), device=device),\n", - " ],\n", - ")\n", - "\n", - "\n", - "def apply_transforms(examples: Dict[str, List], transforms: callable) -> dict:\n", - " \"\"\"Apply transforms to examples.\"\"\"\n", - " # examples is a dict of lists; convert to list of dicts.\n", - " # doing a conversion from PIL to tensor is necessary here when working\n", - " # with the Image feature type.\n", - " value_len = len(list(examples.values())[0])\n", - " examples = [\n", - " {\n", - " k: PILToTensor()(v[i]) if isinstance(v[i], PIL.Image.Image) else v[i]\n", - " for k, v in examples.items()\n", - " }\n", - " for i in range(value_len)\n", - " ]\n", - "\n", - " # apply the transforms to each example\n", - " examples = [transforms(example) for example in examples]\n", - "\n", - " # convert back to a dict of lists\n", - " return {k: [d[k] for d in examples] for k in examples[0]}" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "id": null, - "metadata": {}, - "outputs": [], - "source": [ - "# nih_ds.with_transform(\n", - "# ),\n", - "\n", - "# for batch in nih_dl:" - ] - }, - { - "attachments": {}, - "cell_type": "markdown", - "metadata": {}, - "source": [ - "#### Prediction" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "id": null, - "metadata": {}, - "outputs": [], - "source": [ - "model = xrv.models.DenseNet(weights=\"densenet121-res224-nih\")\n", - "model.eval()\n", - "model.to(device)" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "id": null, - "metadata": {}, - "outputs": [], - "source": [ - "from datasets.combine import concatenate_datasets # noqa: E402\n", - "\n", - "\n", - "def get_predictions_torch(examples):\n", - " images = torch.stack(examples[\"image\"]).squeeze(1)\n", - " preds = model(images)\n", - " return {\"predictions\": preds}\n", - "\n", - "\n", - "with nih_ds.formatted_as(\n", - " \"custom\",\n", 
- " columns=[\"image\"],\n", - " transform=partial(apply_transforms, transforms=transforms),\n", - "):\n", - " preds_ds = nih_ds.map(\n", - " get_predictions_torch,\n", - " batched=True,\n", - " batch_size=TORCH_BATCH_SIZE,\n", - " remove_columns=nih_ds.column_names,\n", - " )\n", - "\n", - " nih_ds = concatenate_datasets([nih_ds, preds_ds], axis=1)" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "id": null, - "metadata": {}, - "outputs": [], - "source": [ - "nih_ds.features" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "id": null, - "metadata": {}, - "outputs": [], - "source": [ - "from cyclops.data.slicer import filter_value # noqa: E402\n", - "\n", - "\n", - "# remove any rows with No Finding == 1\n", - "nih_ds = nih_ds.filter(\n", - " partial(filter_value, column_name=\"No Finding\", value=1, negate=True),\n", - " batched=True,\n", - ")\n", - "\n", - "# remove the No Finding column and adjust the predictions to account for it\n", - "nih_ds = nih_ds.map(\n", - " lambda x: {\n", - " \"predictions\": x[\"predictions\"][:14],\n", - " },\n", - " remove_columns=[\"No Finding\"],\n", - ")\n", - "nih_ds.features" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "id": null, - "metadata": {}, - "outputs": [], - "source": [ - "# get the list of pathologies\n", - "pathologies = model.pathologies[:14]\n", - "pathologies" - ] - }, - { - "attachments": {}, - "cell_type": "markdown", - "metadata": {}, - "source": [ - "#### Evaluation" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "id": null, - "metadata": {}, - "outputs": [], - "source": [ - "# define the slices\n", - "slices = [\n", - " {\"Patient Gender\": {\"value\": \"M\"}},\n", - " {\"Patient Age\": {\"min_value\": 20, \"max_value\": 40}},\n", - "]\n", - "\n", - "# create the slice functions\n", - "slice_spec = SliceSpec(spec_list=slices)" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "id": null, - "metadata": {}, - "outputs": [], 
- "source": [ - "auroc = create_metric(\n", - " metric_name=\"auroc\",\n", - " task=\"multilabel\",\n", - " num_labels=len(pathologies),\n", - " thresholds=np.arange(0, 1, 0.01),\n", - ")" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "id": null, - "metadata": {}, - "outputs": [], - "source": [ - "nih_eval_results = evaluator.evaluate(\n", - " dataset=nih_ds,\n", - " metrics=auroc,\n", - " feature_columns=\"image\",\n", - " target_columns=pathologies,\n", - " prediction_column_prefix=\"predictions\",\n", - " remove_columns=\"image\",\n", - " slice_spec=slice_spec,\n", - ")" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "id": null, - "metadata": {}, - "outputs": [], - "source": [ - "# plot the results\n", - "plots = []\n", - "\n", - "for slice_name, slice_results in nih_eval_results.items():\n", - " plots.append(\n", - " go.Scatter(\n", - " x=pathologies,\n", - " y=slice_results[\"MultilabelAUROC\"],\n", - " name=\"Overall\" if slice_name == \"overall\" else slice_name,\n", - " mode=\"markers\",\n", - " ),\n", - " )\n", - "\n", - "fig = go.Figure(data=plots)\n", - "fig.update_layout(\n", - " title=\"Multilabel AUROC by Pathology and Slice\",\n", - " title_x=0.5,\n", - " title_font_size=20,\n", - " xaxis_title=\"Pathology\",\n", - " yaxis_title=\"Multilabel AUROC\",\n", - ")\n", - "fig.update_traces(\n", - " marker={\"size\": 12, \"line\": {\"width\": 2, \"color\": \"DarkSlateGrey\"}},\n", - " selector={\"mode\": \"markers\"},\n", - ")\n", - "fig.show()" - ] - }, - { - "attachments": {}, - "cell_type": "markdown", - "metadata": {}, - "source": [ - "#### Fairness" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "id": null, - "metadata": {}, - "outputs": [], - "source": [ - "specificity = create_metric(\n", - " metric_name=\"specificity\",\n", - " task=\"multilabel\",\n", - " num_labels=len(pathologies),\n", - ")\n", - "sensitivity = create_metric(\n", - " metric_name=\"sensitivity\",\n", - " 
task=\"multilabel\",\n", - " num_labels=len(pathologies),\n", - ")\n", - "\n", - "fpr = 1 - specificity\n", - "fnr = 1 - sensitivity\n", - "\n", - "balanced_error_rate = (fpr + fnr) / 2" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "id": null, - "metadata": {}, - "outputs": [], - "source": [ - "nih_fairness_result = evaluate_fairness(\n", - " metrics=balanced_error_rate,\n", - " metric_name=\"BalancedErrorRate\",\n", - " dataset=nih_ds,\n", - " remove_columns=\"image\",\n", - " target_columns=pathologies,\n", - " prediction_columns=\"predictions\",\n", - " groups=[\"Patient Age\", \"Patient Gender\"],\n", - " group_bins={\"Patient Age\": [20, 40, 60, 80]},\n", - " group_base_values={\"Patient Age\": 20, \"Patient Gender\": \"M\"},\n", - ")" - ] - }, - { - "attachments": {}, - "cell_type": "markdown", - "metadata": {}, - "source": [ - "##### Plots" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "id": null, - "metadata": {}, - "outputs": [], - "source": [ - "# plot group size per slice\n", - "plots = []\n", - "\n", - "for slice_name, slice_results in nih_fairness_result.items():\n", - " plots.append(\n", - " go.Bar(\n", - " x=[slice_name],\n", - " y=[slice_results[\"Group Size\"]],\n", - " name=slice_name,\n", - " ),\n", - " )\n", - "\n", - "fig = go.Figure(data=plots)\n", - "fig.update_layout(\n", - " title=\"Size of Each Group\",\n", - " title_x=0.5,\n", - " title_font_size=20,\n", - " xaxis_title=\"Group\",\n", - " yaxis_title=\"Group Size\",\n", - " showlegend=False,\n", - ")" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "id": null, - "metadata": {}, - "outputs": [], - "source": [ - "# plot metrics per slice\n", - "plots = []\n", - "\n", - "for slice_name, slice_results in nih_fairness_result.items():\n", - " plots.append(\n", - " go.Scatter(\n", - " x=pathologies,\n", - " y=slice_results[\"BalancedErrorRate\"],\n", - " name=slice_name,\n", - " mode=\"markers\",\n", - " ),\n", - " )\n", - "\n", - "fig = 
go.Figure(data=plots)\n", - "fig.update_layout(\n", - " title=\"Balanced Error Rate by Pathology and Group\",\n", - " title_x=0.5,\n", - " title_font_size=20,\n", - " xaxis_title=\"Pathology\",\n", - " yaxis_title=\"Balanced Error Rate\",\n", - ")\n", - "fig.update_traces(\n", - " marker={\"size\": 12, \"line\": {\"width\": 2, \"color\": \"DarkSlateGrey\"}},\n", - " selector={\"mode\": \"markers\"},\n", - ")\n", - "fig.show()" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "id": null, - "metadata": {}, - "outputs": [], - "source": [ - "# plot parity difference per slice\n", - "plots = []\n", - "\n", - "for slice_name, slice_results in nih_fairness_result.items():\n", - " plots.append(\n", - " go.Scatter(\n", - " x=pathologies,\n", - " y=slice_results[\"BalancedErrorRate Parity\"],\n", - " name=slice_name,\n", - " mode=\"markers\",\n", - " ),\n", - " )\n", - "\n", - "fig = go.Figure(data=plots)\n", - "fig.update_layout(\n", - " title=\"Balanced Error Rate Parity by Pathology and Group\",\n", - " title_x=0.5,\n", - " title_font_size=20,\n", - " xaxis_title=\"Pathology\",\n", - " yaxis_title=\"Balanced Error Rate Parity\",\n", - ")\n", - "fig.update_traces(\n", - " marker={\"size\": 12, \"line\": {\"width\": 2, \"color\": \"DarkSlateGrey\"}},\n", - " selector={\"mode\": \"markers\"},\n", - ")\n", - "fig.show()" - ] - }, - { - "attachments": {}, - "cell_type": "markdown", - "metadata": {}, - "source": [ - "##### Alternative" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "id": null, - "metadata": {}, - "outputs": [], - "source": [ - "fairness_config = FairnessConfig(\n", - " metrics=balanced_error_rate,\n", - " metric_name=\"BalancedErrorRate\",\n", - " dataset=None, # dataset is passed from the evaluator\n", - " target_columns=None, # target columns are passed from the evaluator\n", - " groups=[\"Patient Age\", \"Patient Gender\"],\n", - " group_bins={\"Patient Age\": [20, 40, 60, 80]},\n", - " group_base_values={\"Patient Age\": 
20, \"Patient Gender\": \"M\"},\n", - ")\n", - "\n", - "evaluator.evaluate(\n", - " dataset=nih_ds,\n", - " metrics=auroc,\n", - " target_columns=pathologies,\n", - " slice_spec=slice_spec,\n", - " remove_columns=[\"image\"],\n", - " fairness_config=fairness_config,\n", - " override_fairness_metrics=False,\n", - ")" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "id": null, - "metadata": {}, - "outputs": [], - "source": [] - } - ], - "metadata": { - "kernelspec": { - "display_name": "cyclops", - "language": "python", - "name": "cyclops" - }, - "language_info": { - "codemirror_mode": { - "name": "ipython", - "version": 3 - }, - "file_extension": ".py", - "mimetype": "text/x-python", - "name": "python", - "nbconvert_exporter": "python", - "pygments_lexer": "ipython3", - "version": "3.9.12" - } - }, - "nbformat": 4, - "nbformat_minor": 4 -} diff --git a/nbs/monitor/gemini_drift_experiments/clinical_drift.ipynb b/nbs/monitor/gemini_drift_experiments/clinical_drift.ipynb deleted file mode 100644 index dad9fd2ea..000000000 --- a/nbs/monitor/gemini_drift_experiments/clinical_drift.ipynb +++ /dev/null @@ -1,353 +0,0 @@ -{ - "cells": [ - { - "cell_type": "markdown", - "id": "7ef2b202-6443-4897-9a0c-f4b63392f851", - "metadata": {}, - "source": [ - "### Clinical Drift Detection" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "id": "24df23a5-e1b9-4ce2-bc5a-55965a979b6a", - "metadata": {}, - "outputs": [], - "source": [ - "import datetime\n", - "\n", - "from drift_detector.clinical_applicator import ClinicalShiftApplicator\n", - "from drift_detector.detector import Detector\n", - "from drift_detector.experimenter import Experimenter\n", - "from drift_detector.plotter import plot_drift_samples_pval\n", - "from drift_detector.reductor import Reductor\n", - "from drift_detector.tester import DCTester, TSTester\n", - "from gemini.query import get_gemini_data\n", - "from gemini.utils import get_label, import_dataset_hospital, normalize, process, 
scale" - ] - }, - { - "cell_type": "markdown", - "id": "renewable-mortgage", - "metadata": {}, - "source": [ - "## Config Parameters ##" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "id": "hourly-insider", - "metadata": {}, - "outputs": [], - "source": [ - "PATH = \"/mnt/nfs/project/delirium/drift_exp/JULY-04-2022/\"\n", - "HOSPITALS = [\"SMH\", \"MSH\", \"THPC\", \"THPM\", \"UHNTG\", \"UHNTW\", \"PMH\", \"SBK\"]\n", - "TIMESTEPS = 6\n", - "AGGREGATION_TYPE = \"time\"\n", - "ACADEMIC = [\"MSH\", \"PMH\", \"SMH\", \"UHNTW\", \"UHNTG\", \"PMH\", \"SBK\"]\n", - "COMMUNITY = [\"THPC\", \"THPM\"]\n", - "\n", - "OUTCOME = input(\"Select outcome variable: \")\n", - "SHIFT = input(\"Select experiment: \")\n", - "MODEL_PATH = (\n", - " \"/mnt/nfs/project/delirium/drift_exp/JULY-04-2022/saved_models/\"\n", - " + SHIFT\n", - " + \"_lstm.pt\"\n", - ")\n", - "\n", - "if SHIFT == \"simulated_deployment\":\n", - " exp_params = {\n", - " \"source\": [datetime.date(2015, 1, 1), datetime.date(2019, 1, 1)],\n", - " \"target\": [datetime.date(2019, 1, 1), datetime.date(2020, 8, 1)],\n", - " \"shift_type\": \"source_target\",\n", - " }\n", - "\n", - "if SHIFT == \"covid\":\n", - " exp_params = {\n", - " \"source\": [datetime.date(2019, 1, 1), datetime.date(2020, 2, 1)],\n", - " \"target\": [datetime.date(2020, 3, 1), datetime.date(2020, 8, 1)],\n", - " \"shift_type\": \"time\",\n", - " }\n", - "\n", - "if SHIFT == \"seasonal_summer\":\n", - " exp_params = {\n", - " \"source\": [1, 2, 3, 4, 5, 10, 11, 12],\n", - " \"target\": [6, 7, 8, 9],\n", - " \"shift_type\": \"month\",\n", - " }\n", - "\n", - "if SHIFT == \"seasonal_winter\":\n", - " exp_params = {\n", - " \"source\": [3, 4, 5, 6, 7, 8, 9, 10],\n", - " \"target\": [11, 12, 1, 2],\n", - " \"shift_type\": \"month\",\n", - " }\n", - "\n", - "if SHIFT == \"hosp_type_academic\":\n", - " exp_params = {\n", - " \"source\": ACADEMIC,\n", - " \"target\": COMMUNITY,\n", - " \"shift_type\": \"hospital_type\",\n", - " }\n", 
- "\n", - "if SHIFT == \"hosp_type_community\":\n", - " exp_params = {\n", - " \"source\": COMMUNITY,\n", - " \"target\": ACADEMIC,\n", - " \"shift_type\": \"hospital_type\",\n", - " }" - ] - }, - { - "cell_type": "markdown", - "id": "ee047a8c-392e-46c5-8877-37aa9ca4fc15", - "metadata": {}, - "source": [ - "## Query Data ##" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "id": "7651e075-3673-44ec-9196-bf9bee056cfd", - "metadata": { - "tags": [] - }, - "outputs": [], - "source": [ - "admin_data, x, y = get_gemini_data(PATH)\n", - "\n", - "\n", - "(X_tr, y_tr), (X_val, y_val), (X_t, y_t), feats, admin_data = import_dataset_hospital(\n", - " admin_data,\n", - " x,\n", - " y,\n", - " SHIFT,\n", - " OUTCOME,\n", - " HOSPITALS,\n", - ")\n", - "\n", - "x = x.loc[~x.index.get_level_values(0).isin(X_tr.index.get_level_values(0))]\n", - "\n", - "# Normalize training data\n", - "X_tr_normalized = normalize(admin_data, X_tr, AGGREGATION_TYPE, TIMESTEPS)\n", - "if AGGREGATION_TYPE != \"time\":\n", - " y_tr = get_label(admin_data, X_tr, OUTCOME)\n", - "# Scale training data\n", - "X_tr_scaled = scale(X_tr_normalized)\n", - "# Process training data\n", - "X_tr_final = process(X_tr_scaled, AGGREGATION_TYPE, TIMESTEPS)" - ] - }, - { - "cell_type": "markdown", - "id": "d0e01447-bc6a-429b-99a6-cdc83b91153f", - "metadata": { - "tags": [] - }, - "source": [ - "## Reductor" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "id": "c386e22c-2f8a-4b5c-89c9-d4211bd11bea", - "metadata": {}, - "outputs": [], - "source": [ - "DR_TECHNIQUE = input(\"Select dimensionality reduction technique: \")\n", - "\n", - "reductor = Reductor(\n", - " dr_method=DR_TECHNIQUE,\n", - " model_path=MODEL_PATH,\n", - " n_features=len(feats),\n", - " var_ret=0.8,\n", - ")" - ] - }, - { - "cell_type": "markdown", - "id": "17fae68e-7d4e-4a6b-900b-bcd2e262aca3", - "metadata": {}, - "source": [ - "## Tester" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "id": 
"52851a31-9b6e-4d5b-9e23-52f83b6b4639", - "metadata": {}, - "outputs": [], - "source": [ - "TESTER_METHOD = input(\"Select test method: \")\n", - "tstesters = [\"lk\", \"lsdd\", \"mmd\", \"tabular\", \"ctx_mmd\", \"chi2\", \"fet\", \"ks\"]\n", - "dctesters = [\"spot_the_diff\", \"classifier\", \"classifier_uncertainty\"]\n", - "\n", - "if TESTER_METHOD in tstesters:\n", - " tester = TSTester(\n", - " tester_method=TESTER_METHOD,\n", - " )\n", - "elif TESTER_METHOD in dctesters:\n", - " MODEL_METHOD = input(\"Select model method: \")\n", - " tester = DCTester(\n", - " tester_method=TESTER_METHOD,\n", - " model_method=MODEL_METHOD,\n", - " )\n", - "\n", - " if MODEL_METHOD == \"ctx_mmd\":\n", - " CONTEXT_TYPE = input(\"Select context type: \")\n", - "\n", - " if MODEL_METHOD == \"lk\":\n", - " REPRESENTATION = input(\"Select learned kernel representation: \")" - ] - }, - { - "cell_type": "markdown", - "id": "ad11d8ca-b171-4f8e-a715-36d73f1e1f0d", - "metadata": {}, - "source": [ - "## Detector" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "id": "c2b71002-f587-4f44-96b5-9437c2d27ee6", - "metadata": {}, - "outputs": [], - "source": [ - "detector = Detector(reductor=reductor, tester=tester, p_val_threshold=0.05)\n", - "detector.fit(\n", - " X_tr_final,\n", - " model_path=MODEL_PATH,\n", - " context_type=CONTEXT_TYPE,\n", - " alternative=\"two-sided\",\n", - " n_permutations=100,\n", - ")" - ] - }, - { - "cell_type": "markdown", - "id": "87a291d7-2e69-44ef-884c-d40bc750d396", - "metadata": {}, - "source": [ - "## ClinicalShiftApplicator" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "id": "c6ecf217-e948-4d98-889a-25f5f867c100", - "metadata": {}, - "outputs": [], - "source": [ - "clinicalshiftapplicator = ClinicalShiftApplicator(shift_type=exp_params[\"shift_type\"])\n", - "\n", - "experimenter = Experimenter(\n", - " detector=detector,\n", - " clinicalshiftapplicator=clinicalshiftapplicator,\n", - " admin_data=admin_data,\n", - ")" - 
] - }, - { - "cell_type": "markdown", - "id": "8c4299b2-51dc-4ab1-ac5b-d3de6cf601a7", - "metadata": {}, - "source": [ - "## Experimenter" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "id": "b53edc1b-14ea-43da-a6ef-0f4a6a02e9d5", - "metadata": {}, - "outputs": [], - "source": [ - "X_val, X_t = experimenter.apply_clinical_shift(\n", - " x,\n", - " source=exp_params[\"source\"],\n", - " target=exp_params[\"target\"],\n", - ")\n", - "# Normalize data\n", - "X_val_normalized = normalize(admin_data, X_val, AGGREGATION_TYPE, TIMESTEPS)\n", - "X_t_normalized = normalize(admin_data, X_t, AGGREGATION_TYPE, TIMESTEPS)\n", - "\n", - "if AGGREGATION_TYPE != \"time\":\n", - " y_val = get_label(admin_data, X_val, OUTCOME)\n", - " y_t = get_label(admin_data, X_t, OUTCOME)\n", - "\n", - "# Scale data\n", - "X_val_scaled = scale(X_val_normalized)\n", - "X_t_scaled = scale(X_t_normalized)\n", - "\n", - "# Process data\n", - "X_val_final = process(X_val_scaled, AGGREGATION_TYPE, TIMESTEPS)\n", - "X_t_final = process(X_t_scaled, AGGREGATION_TYPE, TIMESTEPS)\n", - "\n", - "val_drift_results = experimenter.detect_shift_samples(\n", - " X_val_final,\n", - " model_path=MODEL_PATH,\n", - " context_type=CONTEXT_TYPE,\n", - " n_permutations=100,\n", - ")\n", - "test_drift_results = experimenter.detect_shift_samples(\n", - " X_t_final,\n", - " model_path=MODEL_PATH,\n", - " context_type=CONTEXT_TYPE,\n", - " n_permutations=100,\n", - ")\n", - "shift_results = {\"baseline\": val_drift_results, \"experiment\": test_drift_results}" - ] - }, - { - "cell_type": "markdown", - "id": "607d5cce-a1e4-4f81-8899-093edf72b7f6", - "metadata": {}, - "source": [ - "## Plot drift results" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "id": "bab2ee81-c0ec-4fcc-baaa-e76406b35021", - "metadata": {}, - "outputs": [], - "source": [ - "plot_drift_samples_pval(shift_results, 0.05)" - ] - } - ], - "metadata": { - "kernelspec": { - "display_name": "cyclops-KKtuQLwg-py3.9", - 
"language": "python", - "name": "cyclops-kktuqlwg-py3.9" - }, - "language_info": { - "codemirror_mode": { - "name": "ipython", - "version": 3 - }, - "file_extension": ".py", - "mimetype": "text/x-python", - "name": "python", - "nbconvert_exporter": "python", - "pygments_lexer": "ipython3", - "version": "3.9.7" - }, - "vscode": { - "interpreter": { - "hash": "bd2cd438e1c6ddffa3035fc73b17ac5cc0e0ea8897eb8be17cc645c6abf0c8cc" - } - } - }, - "nbformat": 4, - "nbformat_minor": 5 -} diff --git a/nbs/monitor/gemini_drift_experiments/explainer.ipynb b/nbs/monitor/gemini_drift_experiments/explainer.ipynb deleted file mode 100644 index 12cba8abf..000000000 --- a/nbs/monitor/gemini_drift_experiments/explainer.ipynb +++ /dev/null @@ -1,247 +0,0 @@ -{ - "cells": [ - { - "cell_type": "markdown", - "id": "24922ac6-09bb-4008-a91e-5fa321999a77", - "metadata": {}, - "source": [ - "### Explainability API ## " - ] - }, - { - "cell_type": "markdown", - "id": "fa800630-7ed1-4ac4-ac15-6e39dbb68cb7", - "metadata": {}, - "source": [ - "## Imports" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "id": "b841996e-826f-430e-a831-c35dd6658b09", - "metadata": {}, - "outputs": [], - "source": [ - "import itertools\n", - "import os\n", - "import pickle\n", - "import random\n", - "\n", - "import matplotlib.pyplot as plt\n", - "import numpy as np\n", - "import pandas as pd\n", - "from baseline_models.static.utils import run_model\n", - "from drift_detector.explainer import Explainer\n", - "from gemini.query import get_gemini_data\n", - "from gemini.utils import get_label, import_dataset_hospital, normalize, process, scale" - ] - }, - { - "cell_type": "markdown", - "id": "741632c0-aa41-4553-bbb2-eb6fee0b083a", - "metadata": {}, - "source": [ - "## Get data" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "id": "2f887bac-e8f8-4850-970d-b83420bae587", - "metadata": {}, - "outputs": [], - "source": [ - "SHIFT = input(\"Select experiment: \")\n", - "HOSPITALS = [\"SMH\", 
\"MSH\", \"THPC\", \"THPM\", \"UHNTG\", \"UHNTW\", \"PMH\", \"SBK\"]\n", - "OUTCOME = \"mortality\"\n", - "PATH = \"/mnt/nfs/project/delirium/drift_exp/JULY-04-2022/\"\n", - "MODEL_PATH = os.path.join(PATH, \"saved_models\", SHIFT + \"_lstm.pt\")\n", - "TIMESTEPS = 6\n", - "AGGREGATION_TYPE = \"time_flatten\"\n", - "THRESHOLD = 0.05\n", - "NUM_TIMESTEPS = 6\n", - "\n", - "random.seed(1)\n", - "\n", - "admin_data, x, y = get_gemini_data(PATH)\n", - "\n", - "(X_tr, y_tr), (X_val, y_val), (X_t, y_t), feats, admin_data = import_dataset_hospital(\n", - " admin_data,\n", - " x,\n", - " y,\n", - " SHIFT,\n", - " OUTCOME,\n", - " HOSPITALS,\n", - ")\n", - "\n", - "# Normalize data\n", - "X_tr_normalized = normalize(admin_data, X_tr, AGGREGATION_TYPE, TIMESTEPS)\n", - "X_val_normalized = normalize(admin_data, X_val, AGGREGATION_TYPE, TIMESTEPS)\n", - "X_t_normalized = normalize(admin_data, X_t, AGGREGATION_TYPE, TIMESTEPS)\n", - "\n", - "if AGGREGATION_TYPE != \"time\":\n", - " # Get labels\n", - " y_tr = get_label(admin_data, X_tr, OUTCOME)\n", - " y_val = get_label(admin_data, X_val, OUTCOME)\n", - " y_t = get_label(admin_data, X_t, OUTCOME)\n", - "\n", - "# Scale data\n", - "X_tr_scaled = scale(X_tr_normalized)\n", - "X_val_scaled = scale(X_val_normalized)\n", - "X_t_scaled = scale(X_t_normalized)\n", - "\n", - "# Process data\n", - "X_tr_final = process(X_tr_scaled, AGGREGATION_TYPE, TIMESTEPS)\n", - "X_val_final = process(X_val_scaled, AGGREGATION_TYPE, TIMESTEPS)\n", - "X_t_final = process(X_t_scaled, AGGREGATION_TYPE, TIMESTEPS)" - ] - }, - { - "cell_type": "markdown", - "id": "8032f5f1-84cf-4f1c-9a53-de86dc490e09", - "metadata": {}, - "source": [ - "## Get model" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "id": "0b46bcac-41cb-4326-a7f2-2cdc7f8f740f", - "metadata": {}, - "outputs": [], - "source": [ - "MODEL_NAME = input(\"Select Model: \")\n", - "MODEL_PATH = PATH + \"_\".join([SHIFT, OUTCOME, \"_\".join(HOSPITALS), MODEL_NAME]) + \".pkl\"\n", 
- "if os.path.exists(MODEL_PATH):\n", - " optimised_model = pickle.load(open(MODEL_PATH, \"rb\"))\n", - "else:\n", - " optimised_model = run_model(MODEL_NAME, X_tr_final, y_tr, X_val_final, y_val)\n", - " pickle.dump(optimised_model, open(MODEL_PATH, \"wb\"))" - ] - }, - { - "cell_type": "markdown", - "id": "c7697a7b-7aa6-4ffc-8dd7-921474bcdf31", - "metadata": {}, - "source": [ - "## Explain difference in model predictions ## " - ] - }, - { - "cell_type": "code", - "execution_count": null, - "id": "1e2af7f8-5f0c-4cd5-b7f2-7ab07662e7ee", - "metadata": {}, - "outputs": [], - "source": [ - "explainer = Explainer(optimised_model, X_tr_final)\n", - "explainer.get_explainer()" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "id": "d1b809f2-d434-4905-8fc6-7489057de217", - "metadata": {}, - "outputs": [], - "source": [ - "timesteps = [\"T1_\", \"T2_\", \"T3_\", \"T4_\", \"T5_\", \"T6_\"]\n", - "\n", - "flattened_feats = []\n", - "for ts in timesteps:\n", - " flattened_feats.append(ts + feats)\n", - "flattened_feats = list(itertools.chain.from_iterable(flattened_feats))\n", - "\n", - "X_val_df = pd.DataFrame(X_val_final, columns=flattened_feats)\n", - "val_shap_values = explainer.get_shap_values(X_val_df)\n", - "X_test_df = pd.DataFrame(X_t_final, columns=flattened_feats)\n", - "test_shap_values = explainer.get_shap_values(X_test_df)\n", - "\n", - "shap_diff = np.mean(np.abs(test_shap_values.values), axis=0) - np.mean(\n", - " np.abs(val_shap_values.values),\n", - " axis=0,\n", - ")\n", - "shap_min = -0.001\n", - "shap_max = 0.001\n", - "shap_diff_sorted, feats_sorted = zip(\n", - " *sorted(zip(shap_diff, flattened_feats), reverse=True),\n", - ")\n", - "shap_diff_sorted, feats_sorted = zip(\n", - " *(\n", - " (\n", - " (x, y)\n", - " for x, y in zip(shap_diff_sorted, feats_sorted)\n", - " if (x > shap_max or x < shap_min)\n", - " )\n", - " ),\n", - ")\n", - "\n", - "shap_feats = {\"feature\": feats_sorted, \"shap_diff\": list(shap_diff_sorted)}\n", - "\n", 
- "fig, ax = plt.subplots(figsize=(10, 18))\n", - "y_pos = np.arange(len(shap_feats[\"shap_diff\"]))\n", - "ax.barh(y_pos, shap_feats[\"shap_diff\"], align=\"center\")\n", - "ax.set_yticks(y_pos, labels=shap_feats[\"feature\"])\n", - "ax.invert_yaxis() # labels read top-to-bottom\n", - "ax.set_xlabel(\"Mean Difference in Shap Value\")\n", - "ax.set_title(\"Features\")\n", - "plt.show()\n", - "\n", - "shap_diff_sorted, feats_sorted = zip(\n", - " *sorted(zip(shap_diff, flattened_feats), reverse=True),\n", - ")\n", - "shap_diff_sorted, feats_sorted = zip(\n", - " *(((x, y) for x, y in zip(shap_diff_sorted, feats_sorted) if (x != 0))),\n", - ")\n", - "\n", - "for t in [\"T1_\", \"T2_\", \"T4_\", \"T4_\", \"T5_\", \"T6_\"]:\n", - " shap_feats = {\"feature\": feats_sorted, \"shap_diff\": list(shap_diff_sorted)}\n", - " shap_feats = {\n", - " k: [\n", - " x\n", - " for i, x in enumerate(v)\n", - " if any(ts in shap_feats[\"feature\"][i] for ts in [t])\n", - " ]\n", - " for k, v in shap_feats.items()\n", - " }\n", - " shap_feats[\"feature\"] = [x.replace(t, \"\") for x in shap_feats[\"feature\"]]\n", - " fig, ax = plt.subplots(figsize=(12, 12))\n", - " y_pos = np.arange(len(shap_feats[\"shap_diff\"]))\n", - " ax.barh(y_pos, shap_feats[\"shap_diff\"], align=\"center\")\n", - " ax.set_yticks(y_pos, labels=shap_feats[\"feature\"])\n", - " ax.invert_yaxis() # labels read top-to-bottom\n", - " ax.set_xlabel(\"Mean Difference in Shap Value |Target - Source|\")\n", - " ax.set_title(\"Features\")\n", - " plt.show()" - ] - } - ], - "metadata": { - "kernelspec": { - "display_name": "Python 3 (ipykernel)", - "language": "python", - "name": "python3" - }, - "language_info": { - "codemirror_mode": { - "name": "ipython", - "version": 3 - }, - "file_extension": ".py", - "mimetype": "text/x-python", - "name": "python", - "nbconvert_exporter": "python", - "pygments_lexer": "ipython3", - "version": "3.10.4" - }, - "vscode": { - "interpreter": { - "hash": 
"bd2cd438e1c6ddffa3035fc73b17ac5cc0e0ea8897eb8be17cc645c6abf0c8cc" - } - } - }, - "nbformat": 4, - "nbformat_minor": 5 -} diff --git a/nbs/monitor/gemini_drift_experiments/feature_shift.ipynb b/nbs/monitor/gemini_drift_experiments/feature_shift.ipynb deleted file mode 100644 index 3f5992978..000000000 --- a/nbs/monitor/gemini_drift_experiments/feature_shift.ipynb +++ /dev/null @@ -1,283 +0,0 @@ -{ - "cells": [ - { - "cell_type": "code", - "execution_count": null, - "id": "informed-tension", - "metadata": {}, - "outputs": [], - "source": [ - "import pickle\n", - "from os import path\n", - "from time import localtime, strftime, time\n", - "\n", - "import numpy as np\n", - "import torch\n", - "from drift_detection.fsd import FeatureShiftDetector\n", - "from gemini.utils import import_dataset_hospital\n", - "from sklearn.metrics import confusion_matrix as sklearn_confusion_matrix" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "id": "intense-hampton", - "metadata": {}, - "outputs": [], - "source": [ - "HOSPITAL = [\"SMH\", \"MSH\", \"THPC\", \"THPM\", \"UHNTG\", \"UHNTW\", \"PMH\", \"SBK\"]\n", - "NA_CUTOFF = 0.6\n", - "SHIFT_EXPERIMENT = input(\"Select experiment: \")\n", - "OUTCOME = input(\"Select outcome variable: \")\n", - "\n", - "(\n", - " (X_train, y_train),\n", - " (X_val, y_val),\n", - " (X_test, y_test),\n", - " feats,\n", - " orig_dims,\n", - ") = import_dataset_hospital(\n", - " SHIFT_EXPERIMENT,\n", - " OUTCOME,\n", - " HOSPITAL,\n", - " NA_CUTOFF,\n", - " shuffle=True,\n", - ")" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "id": "indoor-harvard", - "metadata": {}, - "outputs": [], - "source": [ - "# # Global Experiment Parameters\n", - "n_samples = 100 # The number of samples in p, q (thus n_samples_total = n_samples*2)\n", - "n_bootstrap_runs = 50\n", - "n_conditional_expectation = 30\n", - "n_inner_expectation = n_conditional_expectation\n", - "alpha = 0.05 # Significance level\n", - "data_family = \"Copula\"\n", - 
"a = 0.5\n", - "b = 0.5\n", - "rng = np.random.RandomState(42)\n", - "torch.manual_seed(rng.randint(1000))\n", - "method_list = [\n", - " \"score-method\",\n", - "] # we do not take the deep method into account with the simple boot.\n", - "dataset_list = [\"COVID\"]\n", - "t_split_interval = 50\n", - "n_comp_sensors_list = [1]\n", - "window_size_list = [i * 100 for i in range(0, 11)]\n", - "n_comp_sensors = 1" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "id": "standard-warrior", - "metadata": {}, - "outputs": [], - "source": [ - "n_trials = int(np.ceil((X_train.shape[0] - 2 * n_samples) / t_split_interval))\n", - "n_dim = X_train.shape[1]\n", - "sqrtn = int(np.floor(np.sqrt(n_dim)))\n", - "n_dataset_samples = X_train[n_samples:].shape[\n", - " 0\n", - "] # to account for taking out n_samples for reference dist, p\n", - "rng = np.random.RandomState(42)\n", - "torch.manual_seed(rng.randint(1000))\n", - "print(n_trials)" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "id": "banned-greenhouse", - "metadata": {}, - "outputs": [], - "source": [ - "transform_data = callable()\n", - "do_diff = False\n", - "do_power_transform = False\n", - "\n", - "dataset_name = \"gemini\"\n", - "for method in method_list:\n", - " for shuffle_data_set in [False, True]:\n", - " # Experiment Switches\n", - " if shuffle_data_set:\n", - " shuffle_string = \"time axis shuffled\"\n", - " experiment_name = f\"time-boot-{method}-time-axis-shuffled-on-{dataset_name}\"\n", - " else:\n", - " shuffle_string = \"time axis unshuffled\"\n", - " experiment_name = (\n", - " f\"time-boot-{method}-time-axis-unshuffled-on-{dataset_name}\"\n", - " )\n", - " print()\n", - " print(\n", - " f\"Starting {method} on {dataset_name} dataset with \\\n", - " {shuffle_string} and simple boot\",\n", - " )\n", - "\n", - " n_trials = int(np.ceil((X_train.shape[0] - 2 * n_samples) / t_split_interval))\n", - " n_dim = X_train.shape[1]\n", - " sqrtn = int(np.floor(np.sqrt(n_dim)))\n", - 
" n_dataset_samples = X_train[n_samples:].shape[\n", - " 0\n", - " ] # to account for taking out n_samples for reference dist, p\n", - "\n", - " # Attack testing\n", - " rng = np.random.RandomState(42)\n", - " torch.manual_seed(rng.randint(1000))\n", - "\n", - " time_list = np.zeros(n_trials)\n", - " global_truth = np.zeros(n_trials)\n", - " detection = np.zeros(n_trials)\n", - " detection_results = np.zeros(shape=(n_dim, n_trials, 3))\n", - " j_attack = rng.choice(np.arange(n_dim), replace=True, size=n_trials)\n", - " for idx, feature in enumerate(j_attack[: int(n_trials / 2)]):\n", - " detection_results[feature, idx, 1] = 1 # recording where attacks happen\n", - " global_truth[idx] = 1\n", - "\n", - " exception_occurred = 0\n", - " exception_vector = np.full(shape=(n_trials), fill_value=False)\n", - " for test_idx, split_idx in enumerate(\n", - " range(0, X_train.shape[0] - 2 * n_samples, t_split_interval),\n", - " ):\n", - " start = time()\n", - " test_idx = int(test_idx)\n", - " split_idx = int(split_idx)\n", - " slice1 = split_idx\n", - " slice2 = split_idx + 2 * n_samples\n", - " pq = X_train[slice1:slice2] # Two sets of samples\n", - " pq = transform_data(\n", - " pq,\n", - " do_diff=do_diff,\n", - " do_power_transform=do_power_transform,\n", - " )\n", - " p = pq[:n_samples]\n", - " q = pq[n_samples : n_samples * 2].copy()\n", - "\n", - " if np.any(detection_results[:, test_idx, 1] == 1): # attack!\n", - " attacked_features = j_attack[test_idx]\n", - " q[:, attacked_features] = rng.permutation(\n", - " q[:, attacked_features],\n", - " ) # permutes q\n", - "\n", - " # Bootstrap every time\n", - " fsd = FeatureShiftDetector(\n", - " p,\n", - " q,\n", - " rng=rng,\n", - " samples_generator=np.nan,\n", - " detection_method=method,\n", - " n_bootstrap_runs=n_bootstrap_runs,\n", - " n_conditional_expectation=n_conditional_expectation,\n", - " n_attacks=np.nan,\n", - " alpha=alpha,\n", - " j_attack=np.nan,\n", - " attack_testing=False,\n", - " )\n", - " 
bonferroni_threshold_vector = fsd.bonferroni_threshold_vector\n", - " threshold_vector = fsd.threshold_vector\n", - " bootstrap_score_means_vector = fsd.bootstrap_distribution.mean(axis=0)\n", - " bootstrap_score_std_vector = (\n", - " np.std(fsd.bootstrap_distribution, axis=0) + 1e-5\n", - " )\n", - "\n", - " # now check after getting new threshold\n", - " score_vector = np.array(fsd.get_score(p, q))\n", - " detection_results[:, test_idx, 0] = score_vector\n", - " # predicting attack\n", - " if np.any(score_vector >= bonferroni_threshold_vector):\n", - " detection[test_idx] = 1\n", - " normalized_score_vector = (\n", - " score_vector - bootstrap_score_means_vector\n", - " ) / bootstrap_score_std_vector\n", - " attacked_features = normalized_score_vector.argsort()[-1]\n", - " detection_results[attacked_features, test_idx, 2] = 1\n", - " time_list[test_idx] = time() - start\n", - "\n", - " # Recording Attack Results\n", - " confusion_tensor = np.zeros(shape=(n_dim, 2, 2))\n", - " for feature_idx, feature_results in enumerate(detection_results):\n", - " confusion_tensor[feature_idx] = sklearn_confusion_matrix(\n", - " feature_results[:, 1],\n", - " feature_results[:, 2],\n", - " labels=[0, 1],\n", - " )\n", - "\n", - " # overall detection confusion matrix\n", - " global_detection_confusion_matrix = sklearn_confusion_matrix(\n", - " global_truth,\n", - " detection,\n", - " labels=[0, 1],\n", - " )\n", - "\n", - " full_tn, full_fp, full_fn, full_tp = confusion_tensor.sum(axis=0).flatten()\n", - " micro_precision = full_tp / (full_tp + full_fp)\n", - " micro_recall = full_tp / (full_tp + full_fn)\n", - "\n", - " if shuffle_data_set:\n", - " print(\"Time axis shuffled\")\n", - " else:\n", - " print(\"Time axis unshuffled\")\n", - "\n", - " tn, fp, fn, tp = global_detection_confusion_matrix.flatten()\n", - " detection_precision = tp / (tp + fp)\n", - " detection_recall = tp / (tp + fn)\n", - "\n", - " print(\"Results for: \", experiment_name)\n", - " print(f\"Precision: 
{detection_precision * 100:.2f}%\")\n", - " print(f\"Recall: {detection_recall * 100:.2f}%\")\n", - "\n", - " print(f\"Micro-precision: {micro_precision * 100:.2f}%\")\n", - " print(f\"Micro-recall: {micro_recall * 100:.2f}%\")\n", - "\n", - " print(f\"Avg time per test: {time_list.mean():.2f} sec\")\n", - " print(f\"Total time: {time_list.sum():.2f} sec\")\n", - "\n", - " # Saving Score Distributions\n", - " results_dict = {\n", - " \"detection_results\": detection_results,\n", - " \"global_confusion_matrix\": global_detection_confusion_matrix,\n", - " \"confusion_tensor\": confusion_tensor,\n", - " \"times\": time_list,\n", - " }\n", - " experiment_save_name = experiment_name + \"-results_dict.p\"\n", - " pickle.dump(\n", - " results_dict,\n", - " open(path.join(\"..\", \"..\", \"results\", experiment_save_name), \"wb\"),\n", - " )\n", - "print(f'Experiment completed at {strftime(\"%a, %d %b %Y %I:%M%p\", localtime())}')" - ] - } - ], - "metadata": { - "kernelspec": { - "display_name": "Python 3.9.7 ('cyclops-4J2PL5I8-py3.9')", - "language": "python", - "name": "python3" - }, - "language_info": { - "codemirror_mode": { - "name": "ipython", - "version": 3 - }, - "file_extension": ".py", - "mimetype": "text/x-python", - "name": "python", - "nbconvert_exporter": "python", - "pygments_lexer": "ipython3", - "version": "3.9.7" - }, - "vscode": { - "interpreter": { - "hash": "bd2cd438e1c6ddffa3035fc73b17ac5cc0e0ea8897eb8be17cc645c6abf0c8cc" - } - } - }, - "nbformat": 4, - "nbformat_minor": 5 -} diff --git a/nbs/monitor/gemini_drift_experiments/synthetic_drift.ipynb b/nbs/monitor/gemini_drift_experiments/synthetic_drift.ipynb deleted file mode 100644 index ac60ab6f5..000000000 --- a/nbs/monitor/gemini_drift_experiments/synthetic_drift.ipynb +++ /dev/null @@ -1,441 +0,0 @@ -{ - "cells": [ - { - "cell_type": "markdown", - "id": "6b2520a8-d4ad-4941-8ea7-71fdd631225f", - "metadata": {}, - "source": [ - "### Synthetic Drift Detection ###" - ] - }, - { - "cell_type": 
"markdown", - "id": "e8bb193b-16d6-4f63-b5d3-3744cd1380e4", - "metadata": {}, - "source": [ - "## Imports ## " - ] - }, - { - "cell_type": "code", - "execution_count": null, - "id": "checked-supervisor", - "metadata": {}, - "outputs": [], - "source": [ - "import matplotlib.pyplot as plt\n", - "from drift_detector.detector import Detector\n", - "from drift_detector.experimenter import Experimenter\n", - "from drift_detector.plotter import plot_drift_samples_pval\n", - "from drift_detector.reductor import Reductor\n", - "from drift_detector.synthetic_applicator import (\n", - " SyntheticShiftApplicator,\n", - " apply_predefined_shift,\n", - ")\n", - "from drift_detector.tester import DCTester, TSTester\n", - "from gemini.query import get_gemini_data\n", - "from gemini.utils import get_label, import_dataset_hospital, normalize, process, scale" - ] - }, - { - "cell_type": "markdown", - "id": "stopped-relevance", - "metadata": {}, - "source": [ - "## Parameters ##" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "id": "retained-characterization", - "metadata": {}, - "outputs": [], - "source": [ - "PATH = \"/mnt/nfs/project/delirium/drift_exp/JULY-04-2022/\"\n", - "MODEL_PATH = (\n", - " \"/mnt/nfs/project/delirium/drift_exp/JULY-04-2022/saved_models/random_lstm.pt\"\n", - ")\n", - "TIMESTEPS = 6\n", - "AGGREGATION_TYPE = \"time_flatten\"\n", - "CONTEXT_TYPE = \"lstm\"\n", - "REPRESENTATION = \"rf\"\n", - "HOSPITALS = [\"SMH\", \"MSH\", \"THPC\", \"THPM\", \"UHNTG\", \"UHNTW\", \"PMH\", \"SBK\"]\n", - "OUTCOME = \"mortality\"" - ] - }, - { - "cell_type": "markdown", - "id": "9b601256-e684-42ee-a209-63b67fa2031d", - "metadata": {}, - "source": [ - "## Query Data" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "id": "f60048ac-3bb6-4499-a1fd-8fd1f8511aef", - "metadata": {}, - "outputs": [], - "source": [ - "admin_data, x, y = get_gemini_data(PATH)" - ] - }, - { - "cell_type": "markdown", - "id": "ece297c5-43e3-4fbd-a0d9-6763d0f13b81", - 
"metadata": {}, - "source": [ - "## Preprocess Data" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "id": "97a15f71-6b53-4e33-80a4-90bc6cb3b915", - "metadata": {}, - "outputs": [], - "source": [ - "# Get subset\n", - "(X_tr, y_tr), (X_val, y_val), (X_t, y_t), feats, admin_data = import_dataset_hospital(\n", - " admin_data,\n", - " x,\n", - " y,\n", - " \"random\",\n", - " OUTCOME,\n", - " HOSPITALS,\n", - ")\n", - "\n", - "# Normalize data\n", - "X_tr_normalized = normalize(admin_data, X_tr, AGGREGATION_TYPE, TIMESTEPS)\n", - "X_val_normalized = normalize(admin_data, X_val, AGGREGATION_TYPE, TIMESTEPS)\n", - "X_t_normalized = normalize(admin_data, X_t, AGGREGATION_TYPE, TIMESTEPS)\n", - "\n", - "if AGGREGATION_TYPE != \"time\":\n", - " # Get labels\n", - " y_tr = get_label(admin_data, X_tr, OUTCOME)\n", - " y_val = get_label(admin_data, X_val, OUTCOME)\n", - " y_t = get_label(admin_data, X_t, OUTCOME)\n", - "\n", - "# Scale data\n", - "X_tr_scaled = scale(X_tr_normalized)\n", - "X_val_scaled = scale(X_val_normalized)\n", - "X_t_scaled = scale(X_t_normalized)\n", - "\n", - "# Process data\n", - "X_tr_final = process(X_tr_scaled, AGGREGATION_TYPE, TIMESTEPS)\n", - "X_val_final = process(X_val_scaled, AGGREGATION_TYPE, TIMESTEPS)\n", - "X_t_final = process(X_t_scaled, AGGREGATION_TYPE, TIMESTEPS)" - ] - }, - { - "cell_type": "markdown", - "id": "c68dc343-a0f6-47a3-b9a4-ac3e3f0fbb63", - "metadata": {}, - "source": [ - "## Reductor" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "id": "c72515e2-2854-415a-9232-4c0bca3f6798", - "metadata": {}, - "outputs": [], - "source": [ - "DR_TECHNIQUE = input(\"Select dimensionality reduction technique: \")\n", - "\n", - "reductor = Reductor(\n", - " dr_method=DR_TECHNIQUE,\n", - " model_path=MODEL_PATH,\n", - " var_ret=0.8,\n", - ")" - ] - }, - { - "cell_type": "markdown", - "id": "308363ea-83be-4f59-ba5c-719e023d0b5c", - "metadata": {}, - "source": [ - "## Tester" - ] - }, - { - "cell_type": 
"code", - "execution_count": null, - "id": "04c8a3ca-5095-4595-91d8-c9e280c0e97a", - "metadata": { - "tags": [] - }, - "outputs": [], - "source": [ - "TESTER_METHOD = input(\"Select test method: \")\n", - "tstesters = [\"lk\", \"lsdd\", \"mmd\", \"tabular\", \"ctx_mmd\", \"chi2\", \"fet\", \"ks\"]\n", - "dctesters = [\"spot_the_diff\", \"classifier\", \"classifier_uncertainty\"]\n", - "\n", - "if TESTER_METHOD in tstesters:\n", - " tester = TSTester(\n", - " tester_method=TESTER_METHOD,\n", - " )\n", - "elif TESTER_METHOD in dctesters:\n", - " MODEL_METHOD = input(\"Select model method: \")\n", - " tester = DCTester(\n", - " tester_method=TESTER_METHOD,\n", - " model_method=MODEL_METHOD,\n", - " )\n", - "\n", - " if MODEL_METHOD == \"ctx_mmd\":\n", - " CONTEXT_TYPE = input(\"Select context type: \")\n", - "\n", - " if MODEL_METHOD == \"lk\":\n", - " REPRESENTATION = input(\"Select learned kernel representation: \")" - ] - }, - { - "cell_type": "markdown", - "id": "6d643707-c7f6-4b49-8806-4094ee0dbb0d", - "metadata": {}, - "source": [ - "## Detector " - ] - }, - { - "cell_type": "code", - "execution_count": null, - "id": "5ebcbad5-ac85-4f1b-8792-669ec7ff54ca", - "metadata": {}, - "outputs": [], - "source": [ - "detector = Detector(\n", - " reductor=reductor,\n", - " tester=tester,\n", - " p_val_threshold=0.05,\n", - ")\n", - "detector.fit(X_tr_final)" - ] - }, - { - "cell_type": "markdown", - "id": "154c9739-57ea-4b97-aae8-b6e65e331499", - "metadata": {}, - "source": [ - "## SyntheticShiftApplicator" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "id": "ccc93e4b-4505-4b0b-a79a-9b8768cfc80d", - "metadata": {}, - "outputs": [], - "source": [ - "shiftapplicator = SyntheticShiftApplicator(\n", - " shift_type=\"gn_shift\",\n", - ")" - ] - }, - { - "cell_type": "markdown", - "id": "6606a732-0139-409b-8a28-116af000907c", - "metadata": {}, - "source": [ - "## Experimenter" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "id": 
"74f47fec-829e-4d46-bc0b-14c7e67a035a", - "metadata": {}, - "outputs": [], - "source": [ - "experimenter_custom = Experimenter(\n", - " detector=detector,\n", - " shiftapplicator=shiftapplicator,\n", - " admin_data=admin_data,\n", - ")" - ] - }, - { - "cell_type": "markdown", - "id": "addeb3e3-62a7-4af0-80ae-3d8af706c031", - "metadata": {}, - "source": [ - "## Run custom shift experiment" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "id": "35f88d98-d833-4e51-85a3-e6f0468bd5c5", - "metadata": {}, - "outputs": [], - "source": [ - "X_t_final_shifted = experimenter_custom.apply_synthetic_shift(\n", - " X_t_final,\n", - " shift_type=\"gn_shift\",\n", - " delta=0.01,\n", - " noise_amt=0.01,\n", - " clip=False,\n", - ")\n", - "\n", - "results = experimenter_custom.detect_shift_samples(X_t_final_shifted)\n", - "\n", - "results" - ] - }, - { - "cell_type": "markdown", - "id": "sudden-topic", - "metadata": {}, - "source": [ - "## Run predefined shift experiments ##" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "id": "3164d25a-cbea-4b16-80c5-40ae358cc481", - "metadata": {}, - "outputs": [], - "source": [ - "SHIFT = input(\"Select shift experiment: \")\n", - "\n", - "if SHIFT == \"ko_shift\":\n", - " shifts = [\"ko_shift_0.1\", \"ko_shift_0.5\", \"ko_shift_1.0\"]\n", - "elif SHIFT == \"small_gn_shift\":\n", - " shifts = [\"small_gn_shift_0.1\", \"small_gn_shift_0.5\", \"small_gn_shift_1.0\"]\n", - "elif SHIFT == \"medium_gn_shift\":\n", - " shifts = [\"medium_gn_shift_0.1\", \"medium_gn_shift_0.5\", \"medium_gn_shift_1.0\"]\n", - "elif SHIFT == \"large_gn_shift\":\n", - " shifts = [\"large_gn_shift_0.1\", \"large_gn_shift_0.5\", \"large_gn_shift_1.0\"]\n", - "elif SHIFT == \"mfa_shift\":\n", - " shifts = [\"mfa_shift_0.25\", \"mfa_shift_0.5\", \"mfa_shift_0.75\"]\n", - "elif SHIFT == \"cp_shift\":\n", - " shifts = [\"cp_shift_0.25\", \"cp_shift_0.75\"]\n", - "elif SHIFT == \"small_bn_shift\":\n", - " shifts = [\"small_bn_shift_0.1\", 
\"small_bn_shift_0.5\", \"small_bn_shift_1.0\"]\n", - "elif SHIFT == \"medium_bn_shift\":\n", - " shifts = [\"medium_bn_shift_0.1\", \"medium_bn_shift_0.5\", \"medium_bn_shift_1.0\"]\n", - "elif SHIFT == \"large_bn_shift\":\n", - " shifts = [\"large_bn_shift_0.1\", \"large_bn_shift_0.5\", \"large_bn_shift_1.0\"]" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "id": "e876ef31-c3d9-4c03-adee-cc6fe0fdb67d", - "metadata": {}, - "outputs": [], - "source": [ - "experimenter_predefined = Experimenter(detector=detector, admin_data=admin_data)" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "id": "1ae0a5ae-5134-4026-9da1-80fdedcccbb6", - "metadata": {}, - "outputs": [], - "source": [ - "shift_results = {}\n", - "for _si, shift in enumerate(shifts):\n", - " X_t_final_shifted = X_t_final.copy()\n", - " X_t_final_shifted, _ = apply_predefined_shift(shift, X=X_t_final_shifted, y=y_t)\n", - " results = experimenter_predefined.detect_shift_samples(X_t_final_shifted)\n", - " shift_results.update({shift: results})" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "id": "7743e141-f5d0-4e27-aeea-c5fe6fa00673", - "metadata": {}, - "outputs": [], - "source": [ - "X_t_final_shifted = X_t_final.copy()\n", - "X_t_final_shifted, _ = apply_predefined_shift(shift, X=X_t_final_shifted, y=y_t)" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "id": "c601f173-7f3c-49dc-b42a-c1adde1e6840", - "metadata": {}, - "outputs": [], - "source": [ - "fig, ax = plt.subplots(figsize=(11, 6))\n", - "plt.hist(X_val_final[:, 0], bins=50, alpha=0.5, label=\"val\", density=True)\n", - "plt.hist(X_t_final[:, 0], bins=50, alpha=0.5, label=\"test\", density=True)\n", - "plt.hist(X_t_final_shifted[:, 0], bins=50, alpha=0.5, label=\"test+noise\", density=True)\n", - "fig.legend(loc=\"upper right\")\n", - "plt.show()" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "id": "b3a02ceb", - "metadata": {}, - "outputs": [], - "source": [ - 
"experimenter2 = Experimenter(detector=detector, admin_data=admin_data)" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "id": "20a59f6d-a296-4736-b4ea-a5baa07d0879", - "metadata": {}, - "outputs": [], - "source": [ - "experimenter2.detect_shift_sample(X_t_final_shifted, sample=100)" - ] - }, - { - "cell_type": "markdown", - "id": "b2ce64ff-70fa-46f1-82bf-277ed96e9dfa", - "metadata": {}, - "source": [ - "## Plot shift experiments" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "id": "horizontal-holiday", - "metadata": {}, - "outputs": [], - "source": [ - "plot_drift_samples_pval(shift_results, 0.05)" - ] - } - ], - "metadata": { - "kernelspec": { - "display_name": "Python 3 (ipykernel)", - "language": "python", - "name": "python3" - }, - "language_info": { - "codemirror_mode": { - "name": "ipython", - "version": 3 - }, - "file_extension": ".py", - "mimetype": "text/x-python", - "name": "python", - "nbconvert_exporter": "python", - "pygments_lexer": "ipython3", - "version": "3.10.4" - }, - "vscode": { - "interpreter": { - "hash": "bd2cd438e1c6ddffa3035fc73b17ac5cc0e0ea8897eb8be17cc645c6abf0c8cc" - } - } - }, - "nbformat": 4, - "nbformat_minor": 5 -} diff --git a/nbs/monitor/mortality/data_pipeline.ipynb b/nbs/monitor/mortality/data_pipeline.ipynb deleted file mode 100644 index 0c4288eaa..000000000 --- a/nbs/monitor/mortality/data_pipeline.ipynb +++ /dev/null @@ -1,186 +0,0 @@ -{ - "cells": [ - { - "cell_type": "markdown", - "id": "480200fd-0a26-4cff-92e7-977bd499eddf", - "metadata": {}, - "source": [ - "# Imports" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "id": "34d3377b-4db4-48bd-a398-b7dc2a96a53b", - "metadata": {}, - "outputs": [], - "source": [ - "import time\n", - "\n", - "import pandas as pd\n", - "from drift_detection.gemini.mortality.constants import (\n", - " CLEANED_DIR,\n", - " ENCOUNTERS_FILE,\n", - " OUTCOME_DEATH,\n", - " QUERIED_DIR,\n", - " TARGET_TIMESTAMP,\n", - ")\n", - "from 
drift_detection.gemini.query import main\n", - "\n", - "from cyclops.processors.clean import normalize_names, normalize_values\n", - "from cyclops.processors.column_names import (\n", - " DISCHARGE_TIMESTAMP,\n", - " ENCOUNTER_ID,\n", - " EVENT_NAME,\n", - " EVENT_VALUE,\n", - ")\n", - "from cyclops.processors.feature.split import intersect_datasets\n", - "from cyclops.utils.file import join, save_dataframe" - ] - }, - { - "cell_type": "markdown", - "id": "8af9d390-0916-4710-8e0d-6ae4046db6a3", - "metadata": {}, - "source": [ - "# Query" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "id": "2c824cdd-3399-4c49-9c29-666287916c1b", - "metadata": {}, - "outputs": [], - "source": [ - "t = time.time()\n", - "cohort, events = main()\n", - "print(time.time() - t)\n", - "cohort" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "id": "d97ff77b-2599-4b25-9506-0ce1460f0c0a", - "metadata": {}, - "outputs": [], - "source": [ - "cohort[OUTCOME_DEATH].sum() / len(cohort)" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "id": "8034cecf-060c-48d5-8688-15255628fbbc", - "metadata": {}, - "outputs": [], - "source": [ - "events" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "id": "032186ed-9141-442b-83de-76111c0f25dd", - "metadata": {}, - "outputs": [], - "source": [ - "# Intersect over encounter IDs to get only those encounters common to both\n", - "cohort, events = intersect_datasets([cohort, events], ENCOUNTER_ID)" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "id": "9faedb44-ba80-43b1-838e-a1ce7dab2119", - "metadata": {}, - "outputs": [], - "source": [ - "save_dataframe(events, join(QUERIED_DIR, \"batch_0000.parquet\"))" - ] - }, - { - "cell_type": "markdown", - "id": "26bed3c7-e88e-40a2-815e-a88cfb5856cf", - "metadata": {}, - "source": [ - "# Clean / Preprocess" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "id": "1bc14ee7-d120-47ae-a3d1-b0253fb77cec", - "metadata": {}, - 
"outputs": [], - "source": [ - "death_events = cohort[cohort[OUTCOME_DEATH] == True] # noqa: E712\n", - "death_events = death_events[[ENCOUNTER_ID, DISCHARGE_TIMESTAMP]]\n", - "death_events = death_events.rename({DISCHARGE_TIMESTAMP: TARGET_TIMESTAMP}, axis=1)\n", - "cohort = pd.merge(cohort, death_events, on=ENCOUNTER_ID, how=\"left\")\n", - "cohort" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "id": "4ec131c4-7828-4a12-bb3b-2caea66ff579", - "metadata": {}, - "outputs": [], - "source": [ - "save_dataframe(cohort, ENCOUNTERS_FILE)" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "id": "29c7fa3b-0964-4e24-a00c-bf0600a2ac28", - "metadata": {}, - "outputs": [], - "source": [ - "# Normalize names and string values\n", - "events[EVENT_NAME] = normalize_names(events[EVENT_NAME])\n", - "events[EVENT_VALUE] = normalize_values(events[EVENT_VALUE])\n", - "\n", - "# Convert values to numeric, dropping those which can't be converted\n", - "events[EVENT_VALUE] = pd.to_numeric(events[EVENT_VALUE], errors=\"coerce\")\n", - "print(\"Length before:\", len(events))\n", - "events = events[~events[EVENT_VALUE].isna()]\n", - "print(\"Length after:\", len(events))\n", - "events" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "id": "da6ec8be-8583-43c0-987e-5b8bf01e5275", - "metadata": {}, - "outputs": [], - "source": [ - "save_dataframe(events, join(CLEANED_DIR, \"batch_0000.parquet\"))" - ] - } - ], - "metadata": { - "kernelspec": { - "display_name": "cyclops-KKtuQLwg-py3.9", - "language": "python", - "name": "cyclops-kktuqlwg-py3.9" - }, - "language_info": { - "codemirror_mode": { - "name": "ipython", - "version": 3 - }, - "file_extension": ".py", - "mimetype": "text/x-python", - "name": "python", - "nbconvert_exporter": "python", - "pygments_lexer": "ipython3", - "version": "3.9.7" - } - }, - "nbformat": 4, - "nbformat_minor": 5 -} diff --git a/nbs/monitor/mortality/preprocessing.ipynb b/nbs/monitor/mortality/preprocessing.ipynb 
deleted file mode 100644 index 45f601698..000000000 --- a/nbs/monitor/mortality/preprocessing.ipynb +++ /dev/null @@ -1,1037 +0,0 @@ -{ - "cells": [ - { - "cell_type": "markdown", - "id": "10841305-2ca3-4e00-b9ec-dd7cf3c0d69e", - "metadata": {}, - "source": [ - "# Shared notebook for processing temporal features." - ] - }, - { - "cell_type": "markdown", - "id": "1b737fc6-e557-453b-9967-8cefe2c8d80a", - "metadata": {}, - "source": [ - "# Imports" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "id": "a0d37a3f-c67e-4548-93bd-85d39f7d8d6b", - "metadata": {}, - "outputs": [], - "source": [ - "from functools import reduce\n", - "\n", - "import numpy as np\n", - "import pandas as pd\n", - "from drift_detection.gemini.utils import get_use_case_params\n", - "\n", - "from cyclops.processors.aggregate import (\n", - " Aggregator,\n", - " tabular_as_aggregated,\n", - " timestamp_ffill_agg,\n", - ")\n", - "from cyclops.processors.column_names import (\n", - " ADMIT_TIMESTAMP,\n", - " DISCHARGE_TIMESTAMP,\n", - " ENCOUNTER_ID,\n", - " EVENT_NAME,\n", - " EVENT_TIMESTAMP,\n", - " EVENT_VALUE,\n", - " RESTRICT_TIMESTAMP,\n", - " TIMESTEP,\n", - ")\n", - "from cyclops.processors.constants import ALL, FEATURES, MEAN, NUMERIC, ORDINAL, STANDARD\n", - "from cyclops.processors.feature.feature import TabularFeatures, TemporalFeatures\n", - "from cyclops.processors.feature.vectorize import (\n", - " Vectorized,\n", - " intersect_vectorized,\n", - " split_vectorized,\n", - " vec_index_exp,\n", - ")\n", - "from cyclops.processors.impute import np_ffill_bfill, np_fill_null_num\n", - "from cyclops.utils.file import (\n", - " join,\n", - " load_dataframe,\n", - " load_pickle,\n", - " save_dataframe,\n", - " save_pickle,\n", - " yield_dataframes,\n", - " yield_pickled_files,\n", - ")" - ] - }, - { - "cell_type": "markdown", - "id": "fb49c62b-ceae-4827-90c8-d4411619f089", - "metadata": {}, - "source": [ - "# Choose dataset and use-case" - ] - }, - { - "cell_type": "code", - 
"execution_count": null, - "id": "85e8cb82-ddad-4362-a450-6c351de6ffd4", - "metadata": {}, - "outputs": [], - "source": [ - "DATASET = \"gemini\"\n", - "USE_CASE = \"mortality\"\n", - "\n", - "use_case_params = get_use_case_params(DATASET, USE_CASE)\n", - "input(f\"WARNING: LOADING CONSTANTS FROM {use_case_params}\")" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "id": "d25686f7-9752-48e1-93a5-5d2ad6a2b52f", - "metadata": {}, - "outputs": [], - "source": [ - "cohort = load_dataframe(use_case_params.ENCOUNTERS_FILE)\n", - "cohort = cohort.reset_index(drop=True)\n", - "cohort.head(5)" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "id": "8ccea95f-17af-4b70-8c6a-990a1ebe00a0", - "metadata": {}, - "outputs": [], - "source": [ - "tab_features = TabularFeatures(\n", - " data=cohort,\n", - " features=use_case_params.TAB_FEATURES,\n", - " by=ENCOUNTER_ID,\n", - " force_types=use_case_params.TAB_FEATURES_TYPES,\n", - ")\n", - "\n", - "numeric_features = tab_features.features_by_type(NUMERIC)\n", - "ordinal_features = tab_features.features_by_type(ORDINAL)\n", - "\n", - "if len(ordinal_features) > 0:\n", - " print(ordinal_features[0], \"mapping:\")\n", - " print(tab_features.meta[ordinal_features[0]].get_mapping())\n", - "\n", - "tab_vectorized = tab_features.vectorize(to_binary_indicators=ordinal_features)\n", - "save_pickle(tab_vectorized, use_case_params.TAB_VECTORIZED_FILE)\n", - "save_pickle(tab_features, use_case_params.TAB_FEATURES_FILE)" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "id": "4d3f7441-1fa5-4f3e-86d2-13d5ff016849", - "metadata": {}, - "outputs": [], - "source": [ - "load_dataframe(use_case_params.ENCOUNTERS_FILE)" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "id": "70b0b8aa-701d-400b-bbe8-f6ceb0c484f4", - "metadata": {}, - "outputs": [], - "source": [ - "timestamps = load_dataframe(use_case_params.ENCOUNTERS_FILE)[\n", - " [\n", - " ENCOUNTER_ID,\n", - " ADMIT_TIMESTAMP,\n", - " 
DISCHARGE_TIMESTAMP,\n", - " use_case_params.TARGET_TIMESTAMP,\n", - " ]\n", - "]\n", - "start_timestamps = (\n", - " timestamps[[ENCOUNTER_ID, ADMIT_TIMESTAMP]]\n", - " .set_index(ENCOUNTER_ID)\n", - " .rename({ADMIT_TIMESTAMP: RESTRICT_TIMESTAMP}, axis=1)\n", - ")\n", - "start_timestamps" - ] - }, - { - "cell_type": "markdown", - "id": "2187007a-7fc8-4c33-b60a-44cb05562c16", - "metadata": { - "tags": [] - }, - "source": [ - "# Temporal-specific processing" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "id": "10e89fbb-fffe-462c-a1fd-f9f9ce79dcc0", - "metadata": {}, - "outputs": [], - "source": [ - "# Determine which events to keep\n", - "# Keep only the most popular events where the values are not null\n", - "all_top_events = []\n", - "for _i, events in enumerate(yield_dataframes(use_case_params.CLEANED_DIR, log=False)):\n", - " top_events = (\n", - " events[EVENT_NAME][~events[EVENT_VALUE].isna()]\n", - " .value_counts()[: use_case_params.TOP_N_EVENTS]\n", - " .index\n", - " )\n", - "\n", - " all_top_events.append(top_events)\n", - "\n", - " del events\n", - "\n", - "# Take only the events common to every file\n", - "top_events = reduce(np.intersect1d, tuple(all_top_events))\n", - "\n", - "top_events" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "id": "44595c02-6526-4e6e-930d-53d2e4a3481a", - "metadata": {}, - "outputs": [], - "source": [ - "len(top_events)" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "id": "f0de07f3-12a3-492d-aa2b-5260480ff962", - "metadata": {}, - "outputs": [], - "source": [ - "aggregator = Aggregator(\n", - " aggfuncs={EVENT_VALUE: MEAN},\n", - " timestamp_col=EVENT_TIMESTAMP,\n", - " time_by=ENCOUNTER_ID,\n", - " agg_by=[ENCOUNTER_ID, EVENT_NAME],\n", - " timestep_size=use_case_params.TIMESTEP_SIZE,\n", - " window_duration=use_case_params.WINDOW_DURATION,\n", - ")" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "id": "afa902f2-6213-4e7d-8ae6-b9c926aa232a", - 
"metadata": {}, - "outputs": [], - "source": [ - "# Aggregate\n", - "skip_n = 0\n", - "generator = yield_dataframes(use_case_params.CLEANED_DIR, skip_n=skip_n, log=False)\n", - "\n", - "for save_count, events in enumerate(generator):\n", - " # Take only the top events\n", - " events = events[events[EVENT_NAME].isin(top_events)]\n", - "\n", - " # Aggregate\n", - " events = events.reset_index(drop=True)\n", - " tmp_features = TemporalFeatures(\n", - " events,\n", - " features=EVENT_VALUE,\n", - " by=[ENCOUNTER_ID, EVENT_NAME],\n", - " timestamp_col=EVENT_TIMESTAMP,\n", - " aggregator=aggregator,\n", - " )\n", - "\n", - " aggregated = tmp_features.aggregate(window_start_time=start_timestamps)\n", - "\n", - " save_dataframe(\n", - " aggregated,\n", - " join(use_case_params.AGGREGATED_DIR, \"batch_\" + f\"{save_count + skip_n:04d}\"),\n", - " )\n", - " del events" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "id": "ff04dfbd-09bf-4f95-be5c-365316989cb6", - "metadata": {}, - "outputs": [], - "source": [ - "# Vectorize\n", - "skip_n = 0\n", - "generator = yield_dataframes(use_case_params.AGGREGATED_DIR, skip_n=skip_n, log=False)\n", - "for save_count, aggregated in enumerate(generator):\n", - " vec = aggregator.vectorize(aggregated)\n", - " save_pickle(\n", - " vec,\n", - " join(use_case_params.VECTORIZED_DIR, \"batch_\" + f\"{save_count + skip_n:04d}\"),\n", - " )" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "id": "6136fa3e-5df9-4ae0-8306-9e7a045c6742", - "metadata": {}, - "outputs": [], - "source": [ - "# Take all Vectorized objects and turn them into a single object\n", - "vecs = list(yield_pickled_files(use_case_params.VECTORIZED_DIR))\n", - "encounter_axis = vecs[0].get_axis(ENCOUNTER_ID)\n", - "res = np.concatenate([vec.data for vec in vecs], axis=encounter_axis)\n", - "indexes = vecs[0].indexes\n", - "indexes[encounter_axis] = np.concatenate([vec.indexes[encounter_axis] for vec in vecs])\n", - "temp_vectorized = 
Vectorized(res, indexes, vecs[0].axis_names)\n", - "del res" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "id": "8cc24eed-f576-4057-9464-9a91b906ceb5", - "metadata": {}, - "outputs": [], - "source": [ - "temp_vectorized.shape" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "id": "aa840bb6-043c-466d-8ff1-0418a3ccd394", - "metadata": {}, - "outputs": [], - "source": [ - "temp_vectorized.axis_names" - ] - }, - { - "cell_type": "markdown", - "id": "0d9f6c14-7246-41a9-8f8b-cfcdaad02355", - "metadata": {}, - "source": [ - "## Target creation" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "id": "09ffdc46-6f93-45a1-a9dd-7621284ac655", - "metadata": {}, - "outputs": [], - "source": [ - "def compute_timestep(timestamps, event):\n", - " timestamps[f\"{event}_after_admit\"] = timestamps[event] - timestamps[ADMIT_TIMESTAMP]\n", - " timestamps[f\"{event}_timestep\"] = (\n", - " timestamps[f\"{event}_after_admit\"]\n", - " / pd.Timedelta(f\"{use_case_params.TIMESTEP_SIZE} hour\")\n", - " ).apply(np.floor)\n", - " return timestamps\n", - "\n", - "\n", - "timestamps[\"target\"] = timestamps[use_case_params.TARGET_TIMESTAMP] - pd.DateOffset(\n", - " hours=use_case_params.PREDICT_OFFSET,\n", - ")\n", - "timestamps = compute_timestep(timestamps, \"target\")\n", - "timestamps = compute_timestep(timestamps, DISCHARGE_TIMESTAMP)\n", - "timestamps" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "id": "4008eb58-3a91-4009-8fe9-b1c1ac3ab645", - "metadata": {}, - "outputs": [], - "source": [ - "timestamps[~timestamps[use_case_params.TARGET_TIMESTAMP].isna()]" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "id": "2f526901-4292-4be6-ae40-e4fe5ea2864f", - "metadata": {}, - "outputs": [], - "source": [ - "encounter_order = pd.Series(temp_vectorized.get_index(ENCOUNTER_ID))\n", - "encounter_order = encounter_order.rename(ENCOUNTER_ID).to_frame()\n", - "encounter_order" - ] - }, - { - "cell_type": "code", - 
"execution_count": null, - "id": "fb3051ae-e472-45cc-8586-a6ce37997329", - "metadata": {}, - "outputs": [], - "source": [ - "discharge_timestep = DISCHARGE_TIMESTAMP + \"_timestep\"\n", - "timesteps = timestamps[[ENCOUNTER_ID, \"target_timestep\", discharge_timestep]]\n", - "aligned_timestamps = pd.merge(encounter_order, timesteps, on=ENCOUNTER_ID, how=\"left\")\n", - "aligned_timestamps" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "id": "461e82a1-9f50-4099-9683-d482974db839", - "metadata": {}, - "outputs": [], - "source": [ - "num_timesteps = int(use_case_params.WINDOW_DURATION / use_case_params.TIMESTEP_SIZE)\n", - "shape = (len(aligned_timestamps), num_timesteps)\n", - "\n", - "arr1 = timestamp_ffill_agg(\n", - " aligned_timestamps[\"target_timestep\"],\n", - " num_timesteps,\n", - " fill_nan=2,\n", - ")\n", - "arr2 = timestamp_ffill_agg(\n", - " aligned_timestamps[discharge_timestep],\n", - " num_timesteps,\n", - " val=-1,\n", - " fill_nan=2,\n", - ")\n", - "targets = np.minimum(arr1, arr2)\n", - "targets[targets == 2] = 0\n", - "targets[126:146]" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "id": "bede3ff8-3d76-4139-914a-de8239d20d73", - "metadata": {}, - "outputs": [], - "source": [ - "aligned_timestamps.iloc[126:146]" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "id": "cd397c63-3626-4feb-9b86-c2b6a24216f4", - "metadata": {}, - "outputs": [], - "source": [ - "targets = np.expand_dims(np.expand_dims(targets, 0), 2)\n", - "targets.shape" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "id": "5cc19e67-90c7-48fa-b1f7-bc06bf64976b", - "metadata": {}, - "outputs": [], - "source": [ - "temp_vectorized.shape" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "id": "80730791-f362-4c45-8533-33f83004c479", - "metadata": {}, - "outputs": [], - "source": [ - "# Include target\n", - "temp_vectorized = temp_vectorized.concat_over_axis(\n", - " EVENT_NAME,\n", - " targets,\n", - " 
use_case_params.TEMP_TARGETS,\n", - ")\n", - "temp_vectorized.shape" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "id": "14af4cdc-0e82-4b7b-b9d6-69c91b2c8d25", - "metadata": {}, - "outputs": [], - "source": [ - "only_targets = temp_vectorized.take_with_index(EVENT_NAME, use_case_params.TEMP_TARGETS)\n", - "assert np.isnan(only_targets.data).sum() == 0" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "id": "dd73ccf9-d87c-4d45-9dc8-cb11271b0896", - "metadata": {}, - "outputs": [], - "source": [ - "save_pickle(temp_vectorized, use_case_params.TEMP_VECTORIZED_FILE)" - ] - }, - { - "cell_type": "markdown", - "id": "c103337f-dbd8-4f04-96a7-b6e721cd3642", - "metadata": {}, - "source": [ - "# Combined processing" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "id": "d2b89e1c-a1d7-4dd1-b37c-6e78698e3e16", - "metadata": {}, - "outputs": [], - "source": [ - "temp_vectorized = load_pickle(use_case_params.TEMP_VECTORIZED_FILE)" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "id": "a7197544-6b93-4856-8959-c1a8228557db", - "metadata": {}, - "outputs": [], - "source": [ - "tab = tab_features.get_data(to_binary_indicators=ordinal_features).reset_index()\n", - "\n", - "# Take only the encounters with temporal events\n", - "tab = tab[np.in1d(tab[ENCOUNTER_ID].values, temp_vectorized.get_index(ENCOUNTER_ID))]\n", - "\n", - "# Aggregate tabular\n", - "tab_aggregated = tabular_as_aggregated(\n", - " tab=tab,\n", - " index=ENCOUNTER_ID,\n", - " var_name=EVENT_NAME,\n", - " value_name=EVENT_VALUE,\n", - " strategy=ALL,\n", - " num_timesteps=aggregator.window_duration // aggregator.timestep_size,\n", - ")\n", - "tab_aggregated" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "id": "0c547a8b-5ab5-4a85-8eb3-59aaa87f96fe", - "metadata": {}, - "outputs": [], - "source": [ - "# Vectorize tabular\n", - "tab_aggregated_vec = aggregator.vectorize(tab_aggregated)\n", - "tab_aggregated_vec.shape" - ] - }, - { - 
"cell_type": "code", - "execution_count": null, - "id": "f8e7b067-e628-49f7-af5d-e2ea297add57", - "metadata": {}, - "outputs": [], - "source": [ - "temp_vectorized.shape" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "id": "2cf0b6e6-0633-4030-9863-7885d88c344a", - "metadata": {}, - "outputs": [], - "source": [ - "tab_aggregated_vec.shape" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "id": "6406a4f3-b728-4e48-a763-b89253e73af9", - "metadata": {}, - "outputs": [], - "source": [ - "# Combine\n", - "comb_vectorized = temp_vectorized.concat_over_axis(\n", - " EVENT_NAME,\n", - " tab_aggregated_vec.data,\n", - " tab_aggregated_vec.get_index(EVENT_NAME),\n", - ")\n", - "comb_vectorized.shape" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "id": "31dfd482-0b8f-4c30-8d56-7e6b89c4f093", - "metadata": {}, - "outputs": [], - "source": [ - "# Don't include any of the tabular targets - split out to avoid label leakage\n", - "comb_vectorized, _ = comb_vectorized.split_out(EVENT_NAME, use_case_params.TAB_TARGETS)\n", - "comb_vectorized.shape" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "id": "6e8f266a-007a-4e50-b525-7c47fb475d9f", - "metadata": {}, - "outputs": [], - "source": [ - "comb_vectorized.get_index(EVENT_NAME)" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "id": "8cd21b0c-4296-4d92-8e0e-9306011b0630", - "metadata": {}, - "outputs": [], - "source": [ - "np.isnan(tab_aggregated_vec.data).sum() / tab_aggregated_vec.data.size" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "id": "7d283d99-9769-4f61-b423-6526a6928a9f", - "metadata": {}, - "outputs": [], - "source": [ - "tab_aggregated_vec.data" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "id": "69aee9c3-8304-41e5-9a35-8885732fdb63", - "metadata": {}, - "outputs": [], - "source": [ - "np.isnan(temp_vectorized.data).sum() / temp_vectorized.data.size" - ] - }, - { - "cell_type": "code", - 
"execution_count": null, - "id": "9eecbad6-29c9-471b-a42c-28fa177e6102", - "metadata": {}, - "outputs": [], - "source": [ - "np.isnan(comb_vectorized.data).sum() / comb_vectorized.data.size" - ] - }, - { - "cell_type": "markdown", - "id": "6fb0a06e-d146-45af-89cd-a90964a6b2df", - "metadata": {}, - "source": [ - "# Prepare splits" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "id": "1401ee79-1a30-41fd-9c42-fc13a6425649", - "metadata": {}, - "outputs": [], - "source": [ - "tab_vectorized.shape, temp_vectorized.shape, comb_vectorized.shape" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "id": "3e4c1178-31af-4713-b3a2-fd5f5aacf845", - "metadata": {}, - "outputs": [], - "source": [ - "tab_vectorized, temp_vectorized, comb_vectorized = intersect_vectorized(\n", - " [tab_vectorized, temp_vectorized, comb_vectorized],\n", - " axes=ENCOUNTER_ID,\n", - ")\n", - "tab_vectorized.shape, temp_vectorized.shape, comb_vectorized.shape" - ] - }, - { - "cell_type": "markdown", - "id": "5d0eaa25-0014-4212-b2ae-794cd70e915d", - "metadata": {}, - "source": [ - "Take only the encounters available in all of the datasets and align the datasets over encounters." 
- ] - }, - { - "cell_type": "code", - "execution_count": null, - "id": "b97c5aba-847e-4f58-9bce-69222f538f12", - "metadata": {}, - "outputs": [], - "source": [ - "# Normalize only numeric features (e.g., not binary indicators)\n", - "# Note: Normalization is not occurring here, we are only doing the setup\n", - "normalizer_map = {feat: STANDARD for feat in numeric_features}\n", - "\n", - "tab_vectorized.add_normalizer(\n", - " FEATURES,\n", - " normalizer_map=normalizer_map,\n", - ")" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "id": "174bb4b0-8948-42d6-a477-b0792a9b7186", - "metadata": {}, - "outputs": [], - "source": [ - "# Normalize all events\n", - "# Note: Normalization is not occurring here, we are only doing the setup\n", - "temp_vectorized.add_normalizer(\n", - " EVENT_NAME,\n", - " normalization_method=STANDARD,\n", - ")\n", - "\n", - "comb_vectorized.add_normalizer(\n", - " EVENT_NAME,\n", - " normalization_method=STANDARD,\n", - ")" - ] - }, - { - "cell_type": "markdown", - "id": "ff6663ae-1ed5-445e-a1d6-9021154fe1a7", - "metadata": {}, - "source": [ - "## Dataset splits" - ] - }, - { - "cell_type": "markdown", - "id": "839995aa-40cb-449c-a057-ad7719998ecf", - "metadata": {}, - "source": [ - "Split into training, validation, and testing datasets such that the tabular and temporal encounters remain aligned." 
- ] - }, - { - "cell_type": "code", - "execution_count": null, - "id": "5a4ec3d2-a221-41ce-93ae-9f5a3018eb00", - "metadata": {}, - "outputs": [], - "source": [ - "tab_splits, temp_splits, comb_splits = split_vectorized(\n", - " [tab_vectorized, temp_vectorized, comb_vectorized],\n", - " use_case_params.SPLIT_FRACTIONS,\n", - " axes=ENCOUNTER_ID,\n", - ")\n", - "tab_train, tab_val, tab_test = tab_splits\n", - "temp_train, temp_val, temp_test = temp_splits\n", - "comb_train, comb_val, comb_test = comb_splits" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "id": "37a6a36d-499d-447e-93eb-3ac002c986c6", - "metadata": {}, - "outputs": [], - "source": [ - "tab_train.shape, tab_val.shape, tab_test.shape" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "id": "edfd710c-b21d-4dc4-b8e4-550235570768", - "metadata": {}, - "outputs": [], - "source": [ - "temp_train.shape, temp_val.shape, temp_test.shape" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "id": "faea0fb5-526a-4f42-af8d-d4ee9c9e60df", - "metadata": {}, - "outputs": [], - "source": [ - "comb_train.shape, comb_val.shape, comb_test.shape" - ] - }, - { - "cell_type": "markdown", - "id": "203ec0ad-a35a-49df-8979-a32779c3ac13", - "metadata": {}, - "source": [ - "## Split features/targets" - ] - }, - { - "cell_type": "markdown", - "id": "9e2f20f3-fe51-41f5-be9f-c74d32519787", - "metadata": {}, - "source": [ - "Split out the targets in the temporal data." 
- ] - }, - { - "cell_type": "code", - "execution_count": null, - "id": "aca6b88f-8f3b-46d2-9d75-3fe8eec1745b", - "metadata": {}, - "outputs": [], - "source": [ - "tab_train_X, tab_train_y = tab_train.split_out(FEATURES, use_case_params.TAB_TARGETS)\n", - "tab_train_X.shape, tab_train_y.shape" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "id": "ec5a428e-fcab-4a49-b9c9-77fe5b56da4f", - "metadata": {}, - "outputs": [], - "source": [ - "tab_val_X, tab_val_y = tab_val.split_out(FEATURES, use_case_params.TAB_TARGETS)\n", - "tab_val_X.shape, tab_val_y.shape" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "id": "98adfca2-4347-475e-ad9d-aeb15a9f3c65", - "metadata": {}, - "outputs": [], - "source": [ - "tab_test_X, tab_test_y = tab_test.split_out(FEATURES, use_case_params.TAB_TARGETS)\n", - "tab_test_X.shape, tab_test_y.shape" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "id": "7a251574-2635-42db-b42e-3d8b77143b64", - "metadata": {}, - "outputs": [], - "source": [ - "temp_train_X, temp_train_y = temp_train.split_out(\n", - " EVENT_NAME,\n", - " use_case_params.TEMP_TARGETS,\n", - ")\n", - "temp_train_X.shape, temp_train_y.shape" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "id": "54795d9d-c2a4-4dd4-9ff9-d3927a0339a3", - "metadata": {}, - "outputs": [], - "source": [ - "temp_val_X, temp_val_y = temp_val.split_out(EVENT_NAME, use_case_params.TEMP_TARGETS)\n", - "temp_val_X.shape, temp_val_y.shape" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "id": "7857300c-e98e-4af3-89eb-4665e887c60c", - "metadata": {}, - "outputs": [], - "source": [ - "temp_test_X, temp_test_y = temp_test.split_out(EVENT_NAME, use_case_params.TEMP_TARGETS)\n", - "temp_test_X.shape, temp_test_y.shape" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "id": "24732d0c-cd51-44ed-b84a-ecb7d11588d7", - "metadata": {}, - "outputs": [], - "source": [ - "comb_train_X, comb_train_y = 
comb_train.split_out(\n", - " EVENT_NAME,\n", - " use_case_params.TEMP_TARGETS,\n", - ")\n", - "comb_train_X.shape, comb_train_y.shape" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "id": "fb708d63-2fff-4f76-814a-376eb6f06969", - "metadata": {}, - "outputs": [], - "source": [ - "comb_val_X, comb_val_y = comb_val.split_out(EVENT_NAME, use_case_params.TEMP_TARGETS)\n", - "comb_val_X.shape, comb_val_y.shape" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "id": "e86adf65-7931-49ce-ab02-eae544758a06", - "metadata": {}, - "outputs": [], - "source": [ - "comb_test_X, comb_test_y = comb_test.split_out(EVENT_NAME, use_case_params.TEMP_TARGETS)\n", - "comb_test_X.shape, comb_test_y.shape" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "id": "2bacb495-8c8a-486a-95db-ca55c02597bf", - "metadata": {}, - "outputs": [], - "source": [] - }, - { - "cell_type": "code", - "execution_count": null, - "id": "e901b017-237b-4cb9-85a7-04172d7dc65c", - "metadata": {}, - "outputs": [], - "source": [ - "def impute(temp_vec):\n", - " # Forward fill then backward fill to get rid of all of the timestep nulls\n", - " temp_vec.impute_over_axis(TIMESTEP, np_ffill_bfill)\n", - "\n", - " # Fill those all-null timesteps with feature mean\n", - " # (since forward and backward filling still leaves them all null)\n", - " axis = temp_vec.get_axis(EVENT_NAME)\n", - "\n", - " for i in range(temp_vec.data.shape[axis]):\n", - " index_exp = vec_index_exp[:, :, i]\n", - " data_slice = temp_vec.data[index_exp]\n", - " mean = np.nanmean(data_slice)\n", - " func = lambda x: np_fill_null_num(x, mean) # noqa: E731\n", - " temp_vec.impute_over_axis(TIMESTEP, func, index_exp=index_exp)\n", - "\n", - " return temp_vec\n", - "\n", - "\n", - "temp_train_X = impute(temp_train_X)\n", - "temp_val_X = impute(temp_val_X)\n", - "temp_test_X = impute(temp_test_X)\n", - "\n", - "comb_train_X = impute(comb_train_X)\n", - "comb_val_X = impute(comb_val_X)\n", - "comb_test_X = 
impute(comb_test_X)" - ] - }, - { - "cell_type": "markdown", - "id": "b45650c0-847e-4b13-974f-49f9b89cc5f3", - "metadata": {}, - "source": [ - "### Normalization" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "id": "47896217-2e2c-449b-9ac9-b3ce08028f91", - "metadata": {}, - "outputs": [], - "source": [ - "splits = (\n", - " tab_train_X,\n", - " tab_val_X,\n", - " tab_test_X,\n", - " temp_train_X,\n", - " temp_val_X,\n", - " temp_test_X,\n", - " comb_train_X,\n", - " comb_val_X,\n", - " comb_test_X,\n", - ")\n", - "\n", - "for split in splits:\n", - " split.fit_normalizer()\n", - " split.normalize()\n", - "\n", - "(\n", - " tab_train_X,\n", - " tab_val_X,\n", - " tab_test_X,\n", - " temp_train_X,\n", - " temp_val_X,\n", - " temp_test_X,\n", - " comb_train_X,\n", - " comb_val_X,\n", - " comb_test_X,\n", - ") = splits" - ] - }, - { - "cell_type": "markdown", - "id": "e1c2c5b5-9979-4678-bead-ad42b547abfb", - "metadata": { - "tags": [] - }, - "source": [ - "## Save" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "id": "9f88e2b1-373e-46c6-97ca-aabd2873fe9a", - "metadata": {}, - "outputs": [], - "source": [ - "# Store data (serialize)\n", - "vectorized = [\n", - " (tab_train_X, \"tab_train_X\"),\n", - " (tab_train_y, \"tab_train_y\"),\n", - " (tab_val_X, \"tab_val_X\"),\n", - " (tab_val_y, \"tab_val_y\"),\n", - " (tab_test_X, \"tab_test_X\"),\n", - " (tab_test_y, \"tab_test_y\"),\n", - " (temp_train_X, \"temp_train_X\"),\n", - " (temp_train_y, \"temp_train_y\"),\n", - " (temp_val_X, \"temp_val_X\"),\n", - " (temp_val_y, \"temp_val_y\"),\n", - " (temp_test_X, \"temp_test_X\"),\n", - " (temp_test_y, \"temp_test_y\"),\n", - " (comb_train_X, \"comb_train_X\"),\n", - " (comb_train_y, \"comb_train_y\"),\n", - " (comb_val_X, \"comb_val_X\"),\n", - " (comb_val_y, \"comb_val_y\"),\n", - " (comb_test_X, \"comb_test_X\"),\n", - " (comb_test_y, \"comb_test_y\"),\n", - "]\n", - "for vec, name in vectorized:\n", - " save_pickle(vec, 
use_case_params.TAB_VEC_COMB + name + \".pkl\")" - ] - } - ], - "metadata": { - "kernelspec": { - "display_name": "cyclops-KKtuQLwg-py3.9", - "language": "python", - "name": "cyclops-kktuqlwg-py3.9" - }, - "language_info": { - "codemirror_mode": { - "name": "ipython", - "version": 3 - }, - "file_extension": ".py", - "mimetype": "text/x-python", - "name": "python", - "nbconvert_exporter": "python", - "pygments_lexer": "ipython3", - "version": "3.9.7" - } - }, - "nbformat": 4, - "nbformat_minor": 5 -} diff --git a/nbs/monitor/mortality/static_model.ipynb b/nbs/monitor/mortality/static_model.ipynb deleted file mode 100644 index f3c352a40..000000000 --- a/nbs/monitor/mortality/static_model.ipynb +++ /dev/null @@ -1,467 +0,0 @@ -{ - "cells": [ - { - "cell_type": "markdown", - "id": "cd4e001b-4921-4a4f-8c2c-0c364d082a3f", - "metadata": {}, - "source": [ - "### Train static models for in hospital mortality risk prediction" - ] - }, - { - "cell_type": "markdown", - "id": "d4681dab-9200-4063-9544-f2013bdcd2ba", - "metadata": {}, - "source": [ - "## Imports" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "id": "cce5e9ec-e3d0-4719-a7ee-43e88dcde8b2", - "metadata": {}, - "outputs": [], - "source": [ - "import os\n", - "import pickle\n", - "import random\n", - "\n", - "import matplotlib.pyplot as plt\n", - "import numpy as np\n", - "import scipy.stats as st\n", - "from drift_detection.baseline_models.static.utils import run_model\n", - "from drift_detection.gemini.utils import (\n", - " get_label,\n", - " import_dataset_hospital,\n", - " normalize,\n", - " process,\n", - " scale,\n", - " unison_shuffled_copies,\n", - ")\n", - "from sklearn.metrics import ( # accuracy_score,; confusion_matrix,; roc_auc_score,\n", - " auc,\n", - " average_precision_score,\n", - " precision_recall_curve,\n", - " roc_curve,\n", - ")\n", - "from use_cases.common.util import get_use_case_params\n", - "\n", - "from cyclops.monitor.plotter import (\n", - " brightness,\n", - " 
colors,\n", - " colorscale,\n", - " errorfill,\n", - " linestyles,\n", - " markers,\n", - " plot_pr,\n", - " plot_roc,\n", - ")\n", - "\n", - "\n", - "get_gemini_data = callable()" - ] - }, - { - "cell_type": "markdown", - "id": "25c975e2-2e68-4283-ad78-40c7e866be3f", - "metadata": {}, - "source": [ - "## Parameters" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "id": "30b4d3ab-f6a9-459a-b755-79160cde4cf4", - "metadata": {}, - "outputs": [], - "source": [ - "PATH = \"/mnt/nfs/project/delirium/drift_exp/JULY-04-2022/\"\n", - "TIMESTEPS = 6\n", - "AGGREGATION_TYPE = \"time_flatten\"\n", - "HOSPITALS = [\"SMH\", \"MSH\", \"THPC\", \"THPM\", \"UHNTG\", \"UHNTW\", \"PMH\", \"SBK\"]\n", - "OUTCOME = \"mortality\"\n", - "THRESHOLD = 0.05" - ] - }, - { - "cell_type": "markdown", - "id": "4177589f-dd6a-4ee1-af9e-124ee35e3d45", - "metadata": {}, - "source": [ - "## Query and process data" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "id": "0bfdbbe6-c047-4f8d-89ef-d25c98a7144a", - "metadata": {}, - "outputs": [], - "source": [ - "DATASET = \"gemini\"\n", - "USE_CASE = \"mortality_decompensation\"\n", - "\n", - "use_case_params = get_use_case_params(DATASET, USE_CASE)\n", - "input(f\"WARNING: LOADING CONSTANTS FROM {use_case_params}\")" - ] - }, - { - "cell_type": "markdown", - "id": "00ed9a03-3746-445d-80a5-11c81020645b", - "metadata": {}, - "source": [ - "## Legacy Scripts" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "id": "293b766a-27f7-4240-afd5-0004c5bff89e", - "metadata": {}, - "outputs": [], - "source": [ - "SHIFT = input(\"Select experiment: \") # hospital_type\n", - "\n", - "admin_data, x, y = get_gemini_data(PATH)\n", - "\n", - "(X_tr, y_tr), (X_val, y_val), (X_t, y_t), feats, admin_data = import_dataset_hospital(\n", - " admin_data,\n", - " x,\n", - " y,\n", - " SHIFT,\n", - " OUTCOME,\n", - " HOSPITALS,\n", - ")\n", - "\n", - "# Normalize data\n", - "X_tr_normalized = normalize(admin_data, X_tr, 
AGGREGATION_TYPE, TIMESTEPS)\n", - "X_val_normalized = normalize(admin_data, X_val, AGGREGATION_TYPE, TIMESTEPS)\n", - "X_t_normalized = normalize(admin_data, X_t, AGGREGATION_TYPE, TIMESTEPS)\n", - "\n", - "y_tr = get_label(admin_data, X_tr, OUTCOME)\n", - "y_val = get_label(admin_data, X_val, OUTCOME)\n", - "y_t = get_label(admin_data, X_t, OUTCOME)\n", - "\n", - "# Scale data\n", - "X_tr_scaled = scale(X_tr_normalized)\n", - "X_val_scaled = scale(X_val_normalized)\n", - "X_t_scaled = scale(X_t_normalized)\n", - "\n", - "# Process data\n", - "X_tr_final = process(X_tr_scaled, AGGREGATION_TYPE, TIMESTEPS)\n", - "X_val_final = process(X_val_scaled, AGGREGATION_TYPE, TIMESTEPS)\n", - "X_t_final = process(X_t_scaled, AGGREGATION_TYPE, TIMESTEPS)" - ] - }, - { - "cell_type": "markdown", - "id": "f8458f04-d058-465a-9356-e2a9237cf2b7", - "metadata": {}, - "source": [ - "## Build Model" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "id": "57fbae00-ba0f-4aee-8eaa-1f241b2e42f9", - "metadata": {}, - "outputs": [], - "source": [ - "MODEL_NAME = input(\"Select Model: \")\n", - "MODEL_PATH = PATH + \"_\".join([SHIFT, OUTCOME, \"_\".join(HOSPITALS), MODEL_NAME]) + \".pkl\"\n", - "if os.path.exists(MODEL_PATH):\n", - " optimised_model = pickle.load(open(MODEL_PATH, \"rb\"))\n", - "else:\n", - " optimised_model = run_model(MODEL_NAME, X_tr_final, y_tr, X_val_final, y_val)\n", - " pickle.dump(optimised_model, open(MODEL_PATH, \"wb\"))" - ] - }, - { - "cell_type": "markdown", - "id": "bd7a7b18-0dfa-4e9a-a080-5e995365036b", - "metadata": {}, - "source": [ - "## Performance ##" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "id": "bffbadf3-460d-4035-9d96-6a2dfd875734", - "metadata": {}, - "outputs": [], - "source": [ - "val_auroc = []\n", - "val_auprc = []\n", - "test_auroc = []\n", - "test_auprc = []\n", - "\n", - "RANDOM_RUNS = 10\n", - "for i in range(RANDOM_RUNS):\n", - " random.seed(i)\n", - "\n", - " (\n", - " (X_tr, y_tr),\n", - " (X_val, 
y_val),\n", - " (X_t, y_t),\n", - " feats,\n", - " admin_data,\n", - " ) = import_dataset_hospital(admin_data, x, y, SHIFT, OUTCOME, HOSPITALS, i)\n", - "\n", - " # Normalize data\n", - " X_tr_normalized = normalize(admin_data, X_tr, AGGREGATION_TYPE, TIMESTEPS)\n", - " X_val_normalized = normalize(admin_data, X_val, AGGREGATION_TYPE, TIMESTEPS)\n", - " X_t_normalized = normalize(admin_data, X_t, AGGREGATION_TYPE, TIMESTEPS)\n", - "\n", - " # Get labels\n", - " y_tr = get_label(admin_data, X_tr, OUTCOME)\n", - " y_val = get_label(admin_data, X_val, OUTCOME)\n", - " y_t = get_label(admin_data, X_t, OUTCOME)\n", - "\n", - " # Scale data\n", - " X_tr_scaled = scale(X_tr_normalized)\n", - " X_val_scaled = scale(X_val_normalized)\n", - " X_t_scaled = scale(X_t_normalized)\n", - "\n", - " # Process data\n", - " X_tr_final = process(X_tr_scaled, AGGREGATION_TYPE, TIMESTEPS)\n", - " X_val_final = process(X_val_scaled, AGGREGATION_TYPE, TIMESTEPS)\n", - " X_t_final = process(X_t_scaled, AGGREGATION_TYPE, TIMESTEPS)\n", - "\n", - " y_pred_prob = optimised_model.predict_proba(X_val_final)[:, 1]\n", - " fpr, tpr, thresholds = roc_curve(y_val, y_pred_prob, pos_label=1)\n", - " roc_auc = auc(fpr, tpr)\n", - " val_auroc.append(roc_auc)\n", - " precision, recall, thresholds = precision_recall_curve(y_val, y_pred_prob)\n", - " auc_pr = auc(recall, precision)\n", - " val_auprc.append(auc_pr)\n", - "\n", - " y_pred_prob = optimised_model.predict_proba(X_t_final)[:, 1]\n", - " fpr, tpr, thresholds = roc_curve(y_t, y_pred_prob, pos_label=1)\n", - " roc_auc = auc(fpr, tpr)\n", - " test_auroc.append(roc_auc)\n", - " precision, recall, thresholds = precision_recall_curve(y_t, y_pred_prob)\n", - " auc_pr = auc(recall, precision)\n", - " test_auprc.append(auc_pr)" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "id": "edf088df-e1a3-4327-94f6-fd143ef27a9d", - "metadata": {}, - "outputs": [], - "source": [ - "def get_bootstrapped_metric(bs_metric):\n", - " bs_mean = 
np.round(np.mean(bs_metric), 3)\n", - " ci = st.t.interval(\n", - " 0.95,\n", - " len(bs_metric) - 1,\n", - " loc=np.mean(bs_metric),\n", - " scale=st.sem(bs_metric),\n", - " )\n", - " return (\n", - " str(bs_mean)\n", - " + \" [\"\n", - " + str(np.round(ci[0], 3))\n", - " + \" - \"\n", - " + str(np.round(ci[1], 3))\n", - " + \"]\"\n", - " )\n", - "\n", - "\n", - "val_auroc_bs = get_bootstrapped_metric(val_auroc)\n", - "val_auprc_bs = get_bootstrapped_metric(val_auprc)\n", - "test_auroc_bs = get_bootstrapped_metric(test_auroc)\n", - "test_auprc_bs = get_bootstrapped_metric(test_auprc)" - ] - }, - { - "cell_type": "markdown", - "id": "cee4cd21-4648-438b-bbf6-248090963c21", - "metadata": {}, - "source": [ - "### Performance on Source Data ###" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "id": "resistant-dover", - "metadata": {}, - "outputs": [], - "source": [ - "y_pred_prob = optimised_model.predict_proba(X_val_final)[:, 1]\n", - "\n", - "fpr, tpr, thresholds = roc_curve(y_val, y_pred_prob, pos_label=1)\n", - "precision, recall, thresholds = precision_recall_curve(y_val, y_pred_prob)\n", - "fig, ax = plt.subplots(nrows=1, ncols=2, figsize=(14, 6))\n", - "plot_roc(ax[0], fpr, tpr, val_auroc_bs)\n", - "plot_pr(ax[1], recall, precision, val_auprc_bs)" - ] - }, - { - "cell_type": "markdown", - "id": "aa8ecd0a-afd2-47f9-97bc-424b8646bcb6", - "metadata": {}, - "source": [ - "### Performance on Target Data ###" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "id": "affecting-marriage", - "metadata": {}, - "outputs": [], - "source": [ - "y_pred_prob = optimised_model.predict_proba(X_t_final)[:, 1]\n", - "fpr, tpr, thresholds = roc_curve(y_t, y_pred_prob, pos_label=1)\n", - "precision, recall, thresholds = precision_recall_curve(y_t, y_pred_prob)\n", - "fig, ax = plt.subplots(nrows=1, ncols=2, figsize=(14, 6))\n", - "plot_roc(ax[0], fpr, tpr, test_auroc_bs)\n", - "plot_pr(ax[1], recall, precision, test_auprc_bs)" - ] - }, - { - 
"cell_type": "markdown", - "id": "dde55789-cb88-49db-9121-27cbdbddc46b", - "metadata": {}, - "source": [ - "## Get AUROC and AUPRC By Varying Sample Sizes in Test Set" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "id": "394030e8-1387-4f57-847d-2c982c7787ee", - "metadata": {}, - "outputs": [], - "source": [ - "SAMPLES = [10, 20, 50, 100, 200, 500, 1000]\n", - "\n", - "RANDOM_RUNS = 100\n", - "samp_metrics = np.ones((len(SAMPLES), RANDOM_RUNS, 2, 2)) * (-1)\n", - "for si, sample in enumerate(SAMPLES):\n", - " for i in range(0, RANDOM_RUNS - 1):\n", - " i = int(i)\n", - " np.random.seed(i)\n", - " X_val_shuffled, y_val_shuffled = unison_shuffled_copies(X_val_final, y_val)\n", - " X_test_shuffled, y_test_shuffled = unison_shuffled_copies(X_t_final, y_t)\n", - "\n", - " y_val_pred_prob = optimised_model.predict_proba(X_val_shuffled[:sample])[:, 1]\n", - " val_fpr, val_tpr, val_thresholds = roc_curve(\n", - " y_val_shuffled[:sample],\n", - " y_val_pred_prob[:sample],\n", - " pos_label=1,\n", - " )\n", - " val_roc_auc = auc(val_fpr, val_tpr)\n", - " val_avg_pr = average_precision_score(\n", - " y_val_shuffled[:sample],\n", - " y_val_pred_prob[:sample],\n", - " )\n", - "\n", - " y_test_pred_prob = optimised_model.predict_proba(X_test_shuffled[:sample])[:, 1]\n", - " test_fpr, test_tpr, test_thresholds = roc_curve(\n", - " y_test_shuffled[:sample],\n", - " y_test_pred_prob[:sample],\n", - " pos_label=1,\n", - " )\n", - " test_roc_auc = auc(test_fpr, test_tpr)\n", - " test_avg_pr = average_precision_score(\n", - " y_test_shuffled[:sample],\n", - " y_test_pred_prob[:sample],\n", - " )\n", - "\n", - " samp_metrics[si, i, 0, :] = [val_roc_auc, val_avg_pr]\n", - " samp_metrics[si, i, 1, :] = [test_roc_auc, test_avg_pr]\n", - "\n", - " mean_samp_metrics = np.mean(samp_metrics, axis=1)\n", - " std_samp_metrics = np.std(samp_metrics, axis=1)" - ] - }, - { - "cell_type": "markdown", - "id": "e742c540-ab70-40cd-8f98-3b899e538920", - "metadata": {}, - "source": [ - 
"## Plot Performance By Varying Sample Sizes in Test Set" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "id": "bfe0ec5d-bfb0-463f-9461-82533fddc6de", - "metadata": {}, - "outputs": [], - "source": [ - "fig = plt.figure(figsize=(8, 6))\n", - "for si, shift in enumerate([\"baseline\", SHIFT]):\n", - " errorfill(\n", - " np.array(SAMPLES[1:]),\n", - " mean_samp_metrics[1:, si, 0],\n", - " std_samp_metrics[1:, si, 0],\n", - " fmt=linestyles[si] + markers[si],\n", - " color=colorscale(colors[si], brightness[si]),\n", - " label=\"%s\" % \"_\".join([shift]),\n", - " )\n", - "plt.xlabel(\"Number of samples from test data\")\n", - "plt.ylabel(\"AUROC\")\n", - "plt.legend()\n", - "plt.show()" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "id": "b9a41070-62be-4d88-8b7f-08188a1ac828", - "metadata": {}, - "outputs": [], - "source": [ - "fig = plt.figure(figsize=(8, 6))\n", - "for si, shift in enumerate([\"baseline\", SHIFT]):\n", - " errorfill(\n", - " np.array(SAMPLES[1:]),\n", - " mean_samp_metrics[1:, si, 1],\n", - " std_samp_metrics[1:, si, 1],\n", - " fmt=linestyles[si] + markers[si],\n", - " color=colorscale(colors[si], brightness[si]),\n", - " label=\"%s\" % \"_\".join([shift]),\n", - " )\n", - "plt.xlabel(\"Number of samples from test data\")\n", - "plt.ylabel(\"AUPRC\")\n", - "plt.legend()\n", - "plt.show()" - ] - } - ], - "metadata": { - "kernelspec": { - "display_name": "Python 3.9.7 ('pycyclops-4J2PL5I8-py3.9')", - "language": "python", - "name": "python3" - }, - "language_info": { - "codemirror_mode": { - "name": "ipython", - "version": 3 - }, - "file_extension": ".py", - "mimetype": "text/x-python", - "name": "python", - "nbconvert_exporter": "python", - "pygments_lexer": "ipython3", - "version": "3.9.7" - }, - "vscode": { - "interpreter": { - "hash": "a3020bd91ee2a3fe37ba2e4a754058255d6b04fc00c4b4bebbda2c828f5bd9d4" - } - } - }, - "nbformat": 4, - "nbformat_minor": 5 -} diff --git a/nbs/monitor/mortality/temporal_model.ipynb 
b/nbs/monitor/mortality/temporal_model.ipynb deleted file mode 100644 index 38f724aa6..000000000 --- a/nbs/monitor/mortality/temporal_model.ipynb +++ /dev/null @@ -1,636 +0,0 @@ -{ - "cells": [ - { - "cell_type": "markdown", - "id": "cd4e001b-4921-4a4f-8c2c-0c364d082a3f", - "metadata": {}, - "source": [ - "## Train temporal models for mortality risk prediction" - ] - }, - { - "cell_type": "markdown", - "id": "d4681dab-9200-4063-9544-f2013bdcd2ba", - "metadata": {}, - "source": [ - "## Imports" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "id": "cce5e9ec-e3d0-4719-a7ee-43e88dcde8b2", - "metadata": {}, - "outputs": [], - "source": [ - "import os\n", - "import random\n", - "\n", - "import matplotlib.pyplot as plt\n", - "import numpy as np\n", - "import pandas as pd\n", - "import plotly.graph_objects as go\n", - "import seaborn as sns\n", - "import torch\n", - "from drift_detection.baseline_models.temporal.pytorch.optimizer import Optimizer\n", - "from drift_detection.baseline_models.temporal.pytorch.utils import (\n", - " get_data,\n", - " get_device,\n", - " get_temporal_model,\n", - " print_metrics_binary,\n", - ")\n", - "from drift_detection.gemini.utils import prep\n", - "from sklearn import metrics\n", - "from torch import nn, optim\n", - "from use_cases.common.util import get_use_case_params\n", - "\n", - "from cyclops.utils.file import load_pickle" - ] - }, - { - "cell_type": "markdown", - "id": "89e583ae-7a70-49d1-9e93-26b692c85d89", - "metadata": {}, - "source": [ - "## Load train/val/test inputs and labels" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "id": "f883f5ca-8d8c-4ec2-878d-0e498f5d7d93", - "metadata": {}, - "outputs": [], - "source": [ - "DATASET = \"gemini\"\n", - "USE_CASE = \"mortality\"\n", - "\n", - "use_case_params = get_use_case_params(DATASET, USE_CASE)\n", - "input(f\"WARNING: LOADING CONSTANTS FROM {use_case_params}\")" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "id": 
"12066550-01d2-4869-a96b-2ab9c37171ae", - "metadata": {}, - "outputs": [], - "source": [ - "X_train_vec = load_pickle(use_case_params.TAB_VEC_COMB + \"comb_train_X\")\n", - "y_train_vec = load_pickle(use_case_params.TAB_VEC_COMB + \"comb_train_y\")\n", - "X_val_vec = load_pickle(use_case_params.TAB_VEC_COMB + \"comb_val_X\")\n", - "y_val_vec = load_pickle(use_case_params.TAB_VEC_COMB + \"comb_val_y\")\n", - "X_test_vec = load_pickle(use_case_params.TAB_VEC_COMB + \"comb_test_X\")\n", - "y_test_vec = load_pickle(use_case_params.TAB_VEC_COMB + \"comb_test_y\")\n", - "\n", - "X_train = prep(X_train_vec.data)\n", - "y_train = prep(y_train_vec.data)\n", - "X_val = prep(X_val_vec.data)\n", - "y_val = prep(y_val_vec.data)\n", - "X_test = prep(X_test_vec.data)\n", - "y_test = prep(y_test_vec.data)" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "id": "fcb16cf4-7838-4757-850e-913093c42181", - "metadata": {}, - "outputs": [], - "source": [ - "unique, train_counts = np.unique(y_train, return_counts=True)\n", - "unique, val_counts = np.unique(y_val, return_counts=True)\n", - "unique, test_counts = np.unique(y_test, return_counts=True)\n", - "print(\n", - " pd.DataFrame(\n", - " {\"Train\": train_counts, \"Val\": val_counts, \"Test\": test_counts},\n", - " index=unique,\n", - " ),\n", - ")" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "id": "a2bc49ca-66ab-4c27-ba5a-a6c6260793c9", - "metadata": {}, - "outputs": [], - "source": [ - "batch_size = 64\n", - "train_dataset = get_data(X_train, y_train)\n", - "train_loader = train_dataset.to_loader(batch_size, shuffle=True)\n", - "\n", - "val_dataset = get_data(X_val, y_val)\n", - "val_loader = val_dataset.to_loader(batch_size)\n", - "\n", - "test_dataset = get_data(X_test, y_test)\n", - "test_loader = torch.utils.data.DataLoader(test_dataset, batch_size=1, shuffle=False)" - ] - }, - { - "cell_type": "markdown", - "id": "dde55789-cb88-49db-9121-27cbdbddc46b", - "metadata": {}, - "source": [ - "## 
Model and training configuration" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "id": "810aab27-5fc4-4271-98eb-16f0917eb520", - "metadata": {}, - "outputs": [], - "source": [ - "output_dim = 1\n", - "batch_size = 64\n", - "input_dim = X_train.shape[2]\n", - "timesteps = X_train.shape[1]\n", - "hidden_dim = 64\n", - "layer_dim = 2\n", - "dropout = 0.2\n", - "n_epochs = 256\n", - "learning_rate = 2e-3\n", - "weight_decay = 1e-6\n", - "last_timestep_only = False\n", - "\n", - "device = get_device()\n", - "\n", - "X_train_inputs = X_train\n", - "X_val_inputs = X_val\n", - "X_test_inputs = X_test\n", - "\n", - "train_dataset = get_data(X_train_inputs, y_train)\n", - "train_loader = train_dataset.to_loader(batch_size, shuffle=True)\n", - "\n", - "val_dataset = get_data(X_val_inputs, y_val)\n", - "val_loader = val_dataset.to_loader(batch_size)\n", - "\n", - "model_params = {\n", - " \"device\": device,\n", - " \"input_dim\": input_dim,\n", - " \"hidden_dim\": hidden_dim,\n", - " \"layer_dim\": layer_dim,\n", - " \"output_dim\": output_dim,\n", - " \"dropout_prob\": dropout,\n", - " \"last_timestep_only\": last_timestep_only,\n", - "}\n", - "split_type = None\n", - "model = get_temporal_model(\"lstm\", model_params).to(device)\n", - "os.chdir(os.path.join(os.getcwd(), \"../../saved_models\"))\n", - "model_path = os.path.join(os.getcwd(), split_type + \"_lstm.pt\")\n", - "model.load_state_dict(torch.load(model_path))" - ] - }, - { - "cell_type": "markdown", - "id": "388af42e-49ac-48c5-a285-8ce66a714518", - "metadata": {}, - "source": [ - "## Training and validation" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "id": "89bdb2aa-8cb2-4632-b6bd-753033903716", - "metadata": {}, - "outputs": [], - "source": [ - "loss_fn = nn.BCEWithLogitsLoss(reduction=\"none\")\n", - "optimizer = optim.Adagrad(\n", - " model.parameters(),\n", - " lr=learning_rate,\n", - " weight_decay=weight_decay,\n", - ")\n", - "lr_scheduler = 
optim.lr_scheduler.StepLR(optimizer, step_size=128, gamma=0.5)\n", - "activation = nn.Sigmoid()\n", - "opt = Optimizer(\n", - " model=model,\n", - " loss_fn=loss_fn,\n", - " optimizer=optimizer,\n", - " activation=activation,\n", - " lr_scheduler=lr_scheduler,\n", - ")\n", - "opt.train(\n", - " train_loader,\n", - " val_loader,\n", - " batch_size=batch_size,\n", - " n_epochs=n_epochs,\n", - " n_features=input_dim,\n", - " timesteps=timesteps,\n", - ")\n", - "opt.plot_losses()" - ] - }, - { - "cell_type": "markdown", - "id": "21c88675-3fb8-4bd8-9022-fc16a250192a", - "metadata": {}, - "source": [ - "## Validation metrics" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "id": "5a631838-e434-4005-8879-8d60b062ac1c", - "metadata": {}, - "outputs": [], - "source": [ - "val_evaluate_loader = torch.utils.data.DataLoader(\n", - " val_dataset,\n", - " batch_size=1,\n", - " shuffle=False,\n", - ")\n", - "y_val_labels, y_val_pred_values, y_val_pred_labels = opt.evaluate(\n", - " val_evaluate_loader,\n", - " batch_size=1,\n", - " n_features=input_dim,\n", - " timesteps=timesteps,\n", - ")\n", - "\n", - "y_val_pred_values = y_val_pred_values[y_val_labels != -1]\n", - "y_val_pred_labels = y_val_pred_labels[y_val_labels != -1]\n", - "y_val_labels = y_val_labels[y_val_labels != -1]\n", - "\n", - "confusion_matrix = metrics.confusion_matrix(y_val_labels, y_val_pred_labels)\n", - "print(confusion_matrix)\n", - "\n", - "pred_metrics = print_metrics_binary(y_val_labels, y_val_pred_values, y_val_pred_labels)\n", - "prec = (pred_metrics[\"prec0\"] + pred_metrics[\"prec1\"]) / 2\n", - "rec = (pred_metrics[\"rec0\"] + pred_metrics[\"rec1\"]) / 2\n", - "print(f\"Precision: {prec}\")\n", - "print(f\"Recall: {rec}\")\n", - "\n", - "\n", - "def plot_pretty_confusion_matrix(confusion_matrix):\n", - " sns.set(style=\"white\")\n", - " fig, ax = plt.subplots(figsize=(9, 6))\n", - " sns.heatmap(\n", - " np.eye(2),\n", - " annot=confusion_matrix,\n", - " fmt=\"g\",\n", - " 
annot_kws={\"size\": 50},\n", - " cmap=sns.color_palette([\"tomato\", \"palegreen\"], as_cmap=True),\n", - " cbar=False,\n", - " yticklabels=[\"False\", \"True\"],\n", - " xticklabels=[\"False\", \"True\"],\n", - " ax=ax,\n", - " )\n", - " ax.xaxis.tick_top()\n", - " ax.xaxis.set_label_position(\"top\")\n", - " ax.tick_params(labelsize=20, length=0)\n", - "\n", - " ax.set_title(\"Confusion Matrix for Test Set\", size=24, pad=20)\n", - " ax.set_xlabel(\"Predicted Values\", size=20)\n", - " ax.set_ylabel(\"Actual Values\", size=20)\n", - "\n", - " additional_texts = [\n", - " \"(True Negative)\",\n", - " \"(False Negative)\",\n", - " \"(False Positive)\",\n", - " \"(True Positive)\",\n", - " ]\n", - " for text_elt, additional_text in zip(ax.texts, additional_texts):\n", - " ax.text(\n", - " *text_elt.get_position(),\n", - " \"\\n\" + additional_text,\n", - " color=text_elt.get_color(),\n", - " ha=\"center\",\n", - " va=\"top\",\n", - " size=24,\n", - " )\n", - " plt.tight_layout()\n", - " plt.show()\n", - "\n", - "\n", - "plot_pretty_confusion_matrix(confusion_matrix)" - ] - }, - { - "cell_type": "markdown", - "id": "9a11dbaa-02e3-4415-84ee-abc159a5337a", - "metadata": {}, - "source": [ - "## Testing metrics" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "id": "14f74137-9a38-40c6-92ee-1e6f4e02290c", - "metadata": {}, - "outputs": [], - "source": [ - "test_dataset = get_data(X_test_inputs, y_test)\n", - "test_loader = torch.utils.data.DataLoader(test_dataset, batch_size=1, shuffle=False)\n", - "y_test_labels, y_pred_values, y_pred_labels = opt.evaluate(\n", - " test_loader,\n", - " batch_size=1,\n", - " n_features=input_dim,\n", - " timesteps=timesteps,\n", - ")\n", - "\n", - "y_pred_values = y_pred_values[y_test_labels != -1]\n", - "y_pred_labels = y_pred_labels[y_test_labels != -1]\n", - "y_test_labels = y_test_labels[y_test_labels != -1]\n", - "\n", - "confusion_matrix = metrics.confusion_matrix(y_test_labels, y_pred_labels)\n", - 
"print(confusion_matrix)\n", - "\n", - "pred_metrics = print_metrics_binary(y_test_labels, y_pred_values, y_pred_labels)\n", - "prec = (pred_metrics[\"prec0\"] + pred_metrics[\"prec1\"]) / 2\n", - "rec = (pred_metrics[\"rec0\"] + pred_metrics[\"rec1\"]) / 2\n", - "print(f\"Precision: {prec}\")\n", - "print(f\"Recall: {rec}\")" - ] - }, - { - "cell_type": "markdown", - "id": "612d8bf4-79c5-449f-ad60-f20217377827", - "metadata": {}, - "source": [ - "## Plot confusion matrix" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "id": "449a9b82-1081-4140-814f-a5ce083662ff", - "metadata": {}, - "outputs": [], - "source": [ - "def plot_confusion_matrix(confusion_matrix, class_names):\n", - " confusion_matrix = (\n", - " confusion_matrix.astype(\"float\") / confusion_matrix.sum(axis=1)[:, np.newaxis]\n", - " )\n", - "\n", - " layout = {\n", - " \"title\": \"Confusion Matrix\",\n", - " \"xaxis\": {\"title\": \"Predicted value\"},\n", - " \"yaxis\": {\"title\": \"Real value\"},\n", - " }\n", - "\n", - " fig = go.Figure(\n", - " data=go.Heatmap(\n", - " z=confusion_matrix,\n", - " x=class_names,\n", - " y=class_names,\n", - " hoverongaps=False,\n", - " colorscale=\"Greens\",\n", - " ),\n", - " layout=layout,\n", - " )\n", - " fig.update_layout(height=512, width=1024)\n", - " fig.show()\n", - "\n", - "\n", - "plot_confusion_matrix(\n", - " confusion_matrix,\n", - " [\"low risk of mortality\", \"high risk of mortality\"],\n", - ")" - ] - }, - { - "cell_type": "markdown", - "id": "a43cfde9-3c6d-4de9-817e-811376f5cb3f", - "metadata": {}, - "source": [ - "## Compute AUROC across timesteps" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "id": "679e86a5-8db4-4e44-bf50-babfb186427b", - "metadata": {}, - "outputs": [], - "source": [ - "y_test_labels, y_pred_values, y_pred_labels = opt.evaluate(\n", - " test_loader,\n", - " batch_size=1,\n", - " n_features=input_dim,\n", - " timesteps=timesteps,\n", - " flatten=False,\n", - ")\n", - "\n", - "num_timesteps 
= y_pred_labels.shape[1]\n", - "auroc_timesteps = []\n", - "for i in range(num_timesteps):\n", - " labels = y_test_labels[:, i]\n", - " pred_vals = y_pred_values[:, i]\n", - " preds = y_pred_labels[:, i]\n", - " pred_vals = pred_vals[labels != -1]\n", - " preds = preds[labels != -1]\n", - " labels = labels[labels != -1]\n", - " pred_metrics = print_metrics_binary(labels, pred_vals, preds, verbose=False)\n", - " auroc_timesteps.append(pred_metrics[\"auroc\"])\n", - "\n", - "\n", - "prediction_hours = list(range(24, 168, 24))\n", - "fig = go.Figure(\n", - " data=[go.Bar(x=prediction_hours, y=auroc_timesteps, name=\"model confidence\")],\n", - ")\n", - "\n", - "fig.update_xaxes(tickvals=prediction_hours)\n", - "fig.update_yaxes(range=[min(auroc_timesteps) - 0.05, max(auroc_timesteps) + 0.05])\n", - "\n", - "fig.update_layout(\n", - " title=\"AUROC split by no. of hours after admission\",\n", - " autosize=False,\n", - " xaxis_title=\"No. of hours after admission\",\n", - ")\n", - "fig.show()" - ] - }, - { - "cell_type": "markdown", - "id": "1aea19ab-7dc3-4f9e-b073-17f788537b81", - "metadata": {}, - "source": [ - "## WIP: Compute accuracy across lead times" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "id": "5ed54c9c-4e81-4bce-9fe7-09a18296c3de", - "metadata": {}, - "outputs": [], - "source": [ - "# timestep_end_timestamps = load_dataframe(os.path.join(BASE_DATA_PATH,\n", - "# \"aggmeta_end_ts\"))\n", - "\n", - "\n", - "# train_val_test_ids = load_dataframe(os.path.join(BASE_DATA_PATH,\n", - "# \"train_val_test_ids\"))\n", - "\n", - "# for timestep in range(num_timesteps):\n", - "\n", - "# for enc_id in test_ids:\n", - "# mortality_timestamp = mortality_events.loc[mortality_events[\"encounter_id\"]\n", - "# == enc_id][\"discharge_timestamp\"]\n", - "# if (lead_time > pd.to_timedelta(0, unit=\"h\")).all():\n", - "\n", - "# if label_ == 1:\n", - "# if label_ == pred_:" - ] - }, - { - "cell_type": "markdown", - "id": 
"d8825f3a-a941-456b-b418-db620c224eb0", - "metadata": {}, - "source": [ - "## Visualize model outputs and labels" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "id": "ee6fbfab-57bc-4f07-9645-740b588d2ea2", - "metadata": {}, - "outputs": [], - "source": [ - "def plot_risk_mortality(predictions, labels=None):\n", - " prediction_hours = list(range(24, 168, 24))\n", - " is_mortality = labels == 1\n", - " after_discharge = labels == -1\n", - " label_h = -0.2\n", - " fig = go.Figure(\n", - " data=[\n", - " go.Scatter(\n", - " mode=\"markers\",\n", - " x=prediction_hours,\n", - " y=[label_h for x in prediction_hours],\n", - " line={\"color\": \"Black\"},\n", - " name=\"low risk of mortality label\",\n", - " marker={\n", - " \"color\": \"Green\",\n", - " \"size\": 20,\n", - " \"line\": {\"color\": \"Black\", \"width\": 2},\n", - " },\n", - " ),\n", - " go.Scatter(\n", - " mode=\"markers\",\n", - " x=[prediction_hours[i] for i, v in enumerate(is_mortality) if v],\n", - " y=[label_h for _, v in enumerate(is_mortality) if v],\n", - " line={\"color\": \"Red\"},\n", - " name=\"high risk of mortality label\",\n", - " marker={\n", - " \"color\": \"Red\",\n", - " \"size\": 20,\n", - " \"line\": {\"color\": \"Black\", \"width\": 2},\n", - " },\n", - " ),\n", - " go.Scatter(\n", - " mode=\"markers\",\n", - " x=[prediction_hours[i] for i, v in enumerate(after_discharge) if v],\n", - " y=[label_h for _, v in enumerate(after_discharge) if v],\n", - " line={\"color\": \"Grey\"},\n", - " name=\"post discharge label\",\n", - " marker={\n", - " \"color\": \"Grey\",\n", - " \"size\": 20,\n", - " \"line\": {\"color\": \"Black\", \"width\": 2},\n", - " },\n", - " ),\n", - " go.Bar(\n", - " x=prediction_hours,\n", - " y=predictions,\n", - " marker_color=\"Red\",\n", - " name=\"model confidence\",\n", - " ),\n", - " ],\n", - " )\n", - " fig.update_yaxes(range=[label_h, 1])\n", - " fig.update_xaxes(tickvals=prediction_hours)\n", - " fig.update_xaxes(showline=True, linewidth=2, 
linecolor=\"black\")\n", - "\n", - " fig.add_hline(y=0.5)\n", - "\n", - " fig.update_layout(\n", - " title=\"Model output visualization\",\n", - " autosize=False,\n", - " xaxis_title=\"No. of hours after admission\",\n", - " yaxis_title=\"Model confidence\",\n", - " )\n", - "\n", - " return fig\n", - "\n", - "\n", - "mortality_cases = [idx for idx, v in enumerate(y_test_labels)]\n", - "sample_idx = random.choice(mortality_cases)\n", - "fig = plot_risk_mortality(\n", - " y_pred_values[sample_idx].squeeze(),\n", - " y_test_labels[sample_idx],\n", - ")\n", - "fig.show()" - ] - }, - { - "cell_type": "markdown", - "id": "e8e182f4-e943-4e8f-ac45-8c445ed064c5", - "metadata": {}, - "source": [ - "## Journal of some experiments" - ] - }, - { - "cell_type": "markdown", - "id": "9cb21dd9-83a7-4b49-b328-655bf02a3c8b", - "metadata": {}, - "source": [ - "\n", - "\n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " ... \n", - " \n", - "
SplitModelAUROC
Random
LSTM0.8005
" - ] - } - ], - "metadata": { - "kernelspec": { - "display_name": "Python 3.9.7 ('pycyclops-4J2PL5I8-py3.9')", - "language": "python", - "name": "python3" - }, - "language_info": { - "codemirror_mode": { - "name": "ipython", - "version": 3 - }, - "file_extension": ".py", - "mimetype": "text/x-python", - "name": "python", - "nbconvert_exporter": "python", - "pygments_lexer": "ipython3", - "version": "3.9.7" - }, - "vscode": { - "interpreter": { - "hash": "a3020bd91ee2a3fe37ba2e4a754058255d6b04fc00c4b4bebbda2c828f5bd9d4" - } - } - }, - "nbformat": 4, - "nbformat_minor": 5 -} diff --git a/nbs/monitor/retraining/confidence.ipynb b/nbs/monitor/retraining/confidence.ipynb deleted file mode 100644 index faa215adb..000000000 --- a/nbs/monitor/retraining/confidence.ipynb +++ /dev/null @@ -1,693 +0,0 @@ -{ - "cells": [ - { - "cell_type": "markdown", - "id": "7456e505-ef1f-4537-b819-f9f433271ed7", - "metadata": {}, - "source": [ - "### Retraining, removing low confidence samples" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "id": "d995dcd3-fb7a-4398-b8b8-984c2ddc78be", - "metadata": {}, - "outputs": [], - "source": [ - "import datetime\n", - "import os\n", - "import pickle\n", - "import random\n", - "from datetime import date\n", - "\n", - "import matplotlib.pyplot as plt\n", - "import numpy as np\n", - "import pandas as pd\n", - "import scipy.stats as st\n", - "import torch\n", - "from baseline_models.temporal.pytorch.metrics import print_metrics_binary\n", - "from baseline_models.temporal.pytorch.optimizer import Optimizer\n", - "from baseline_models.temporal.pytorch.utils import get_data, get_device\n", - "from drift_detector.utils import get_temporal_model\n", - "from gemini.query import get_gemini_data\n", - "from gemini.utils import import_dataset_hospital\n", - "from matplotlib.colors import ListedColormap\n", - "from torch import nn, optim" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "id": 
"17cff9cf-c682-4973-ab27-1161b70d516a", - "metadata": {}, - "outputs": [], - "source": [ - "PATH = \"/mnt/nfs/project/delirium/drift_exp/JULY-04-2022\"\n", - "threshold = 0.05\n", - "num_timesteps = 6\n", - "run = 1\n", - "shift = \"covid\"\n", - "hospital = [\"SBK\", \"UHNTG\", \"THPC\", \"THPM\", \"UHNTW\", \"SMH\", \"MSH\", \"PMH\"]\n", - "outcome = \"mortality\"\n", - "aggregation_type = \"time\"\n", - "scale = True" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "id": "93113177-c6d6-44af-9f74-0c043c781416", - "metadata": {}, - "outputs": [], - "source": [ - "scale_temporal_data = callable()\n", - "reshape_inputs = callable()\n", - "admin_data, x, y = get_gemini_data(PATH)\n", - "x = scale_temporal_data(x)\n", - "X = reshape_inputs(x, num_timesteps)" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "id": "07d91c48-3b72-4510-bfbe-82f70193090c", - "metadata": {}, - "outputs": [], - "source": [ - "(\n", - " (x_train, y_train),\n", - " (x_val, y_val),\n", - " (x_test, y_test),\n", - " feats,\n", - " admin_data,\n", - ") = import_dataset_hospital(\n", - " admin_data,\n", - " x,\n", - " y,\n", - " shift,\n", - " outcome,\n", - " hospital,\n", - " run,\n", - " shuffle=True,\n", - ")\n", - "\n", - "random.seed(1)\n", - "\n", - "# Normalize data\n", - "normalize_data = callable()\n", - "(\n", - " (X_tr_normalized, y_tr),\n", - " (X_val_normalized, y_val),\n", - " (X_t_normalized, y_t),\n", - ") = normalize_data(\n", - " aggregation_type,\n", - " admin_data,\n", - " num_timesteps,\n", - " x_train,\n", - " y_train,\n", - " x_val,\n", - " y_val,\n", - " x_test,\n", - " y_test,\n", - ")\n", - "# Scale data\n", - "scale_data = callable()\n", - "numerical_cols = []\n", - "if scale:\n", - " X_tr_normalized, X_val_normalized, X_t_normalized = scale_data(\n", - " numerical_cols,\n", - " X_tr_normalized,\n", - " X_val_normalized,\n", - " X_t_normalized,\n", - " )\n", - "# Process data\n", - "process_data = callable()\n", - "X_tr_final, 
X_val_final, X_t_final = process_data(\n", - " aggregation_type,\n", - " num_timesteps,\n", - " X_tr_normalized,\n", - " X_val_normalized,\n", - " X_t_normalized,\n", - ")" - ] - }, - { - "cell_type": "markdown", - "id": "cb0a5be5-e6e5-48da-aa6a-94511b907a69", - "metadata": {}, - "source": [ - "## Create Data Streams" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "id": "950b8746-7262-489e-bd31-c4e381900525", - "metadata": {}, - "outputs": [], - "source": [ - "start_date = date(2019, 1, 1)\n", - "end_date = date(2020, 8, 1)\n", - "\n", - "val_ids = list(X_val_normalized.index.get_level_values(0).unique())\n", - "get_streams = callable()\n", - "x_test_stream, y_test_stream, measure_dates_test = get_streams(\n", - " x,\n", - " y,\n", - " admin_data,\n", - " start_date,\n", - " end_date,\n", - " stride=1,\n", - " window=1,\n", - " ids_to_exclude=val_ids,\n", - ")" - ] - }, - { - "cell_type": "markdown", - "id": "9d6b18be-663a-42fc-a4c6-056e2be2fe7a", - "metadata": {}, - "source": [ - "## Dynamic Rolling Window" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "id": "a61f657a-4e2b-4933-ad74-c91bd69b9292", - "metadata": {}, - "outputs": [], - "source": [ - "random.seed(1)\n", - "# rolling window parameters\n", - "threshold = 0.05\n", - "num_timesteps = 6\n", - "stat_window = 30\n", - "lookup_window = 0\n", - "stride = 1\n", - "# model parameters\n", - "output_dim = 1\n", - "batch_size = 64\n", - "input_dim = 108\n", - "timesteps = 6\n", - "hidden_dim = 64\n", - "layer_dim = 2\n", - "dropout = 0.2\n", - "n_epochs = 1\n", - "learning_rate = 2e-3\n", - "weight_decay = 1e-6\n", - "last_timestep_only = False\n", - "device = get_device()\n", - "model_params = {\n", - " \"device\": device,\n", - " \"input_dim\": input_dim,\n", - " \"hidden_dim\": hidden_dim,\n", - " \"layer_dim\": layer_dim,\n", - " \"output_dim\": output_dim,\n", - " \"dropout_prob\": dropout,\n", - " \"last_timestep_only\": last_timestep_only,\n", - "}\n", - "# drift 
detector parameters\n", - "dr_technique = \"BBSDs_trained_LSTM\"\n", - "model_path = os.path.join(os.getcwd(), \"../../saved_models/\" + shift + \"_lstm.pt\")\n", - "md_test = \"MMD\"\n", - "sign_level = 0.05\n", - "sample = 1000\n", - "dataset = \"gemini\"\n", - "context_type = \"rnn\"\n", - "representation = \"rf\"\n", - "\n", - "# retrain parameters\n", - "model_name = \"rnn\"\n", - "retrain = True\n", - "\n", - "# Get shift reductor\n", - "ShiftReductor = callable()\n", - "shift_reductor = ShiftReductor(\n", - " X_tr_final,\n", - " y_tr,\n", - " dr_technique,\n", - " dataset,\n", - " var_ret=0.8,\n", - " model_path=model_path,\n", - ")\n", - "# Get shift detector\n", - "ShiftDetector = callable()\n", - "shift_detector = ShiftDetector(\n", - " dr_technique,\n", - " md_test,\n", - " sign_level,\n", - " shift_reductor,\n", - " sample,\n", - " dataset,\n", - " feats,\n", - " model_path,\n", - " context_type,\n", - " representation,\n", - ")\n", - "\n", - "if model_name == \"rnn\":\n", - " model = get_temporal_model(\"lstm\", model_params).to(device)\n", - " model_path = os.path.join(os.getcwd(), \"../../saved_models/\", shift + \"_lstm.pt\")\n", - " model.load_state_dict(torch.load(model_path))\n", - "\n", - " loss_fn = nn.BCEWithLogitsLoss(reduction=\"none\")\n", - " optimizer = optim.Adagrad(\n", - " model.parameters(),\n", - " lr=learning_rate,\n", - " weight_decay=weight_decay,\n", - " )\n", - " lr_scheduler = optim.lr_scheduler.StepLR(optimizer, step_size=128, gamma=0.5)\n", - " activation = nn.Sigmoid()\n", - " opt = Optimizer(\n", - " model=model,\n", - " loss_fn=loss_fn,\n", - " optimizer=optimizer,\n", - " activation=activation,\n", - " lr_scheduler=lr_scheduler,\n", - " )\n", - "if model_name == \"gbt\":\n", - " with open(model_path, \"rb\") as f:\n", - " model = pickle.load(f)\n", - "\n", - "# ------------------------------------------------------------------\n", - "# low confidence adjusting drift detector - if drift is significant,\n", - "# reference 
dataset is reset to current time\n", - "# ------------------------------------------------------------------\n", - "\n", - "\n", - "def confidence_rolling_window(\n", - " X_train,\n", - " X_stream,\n", - " y_stream,\n", - " shift_detector,\n", - " sample,\n", - " stat_window,\n", - " lookup_window,\n", - " stride,\n", - " num_timesteps,\n", - " threshold,\n", - " model_name,\n", - " model,\n", - " opt=None,\n", - " X_ref=None,\n", - " retrain=True,\n", - "):\n", - " p_vals = np.asarray([])\n", - " dist_vals = np.asarray([])\n", - " rolling_metrics = []\n", - " run_length = int(stat_window)\n", - "\n", - " i = stat_window\n", - " p_val = 1\n", - " total_alarms = 0\n", - " verbose = 0\n", - "\n", - " if X_ref is not None:\n", - " X_prev = X_ref\n", - " # create val loader\n", - "\n", - " while i + stat_window + lookup_window <= len(X_stream):\n", - " if p_val < threshold:\n", - " if retrain:\n", - " # Get data for updated fit\n", - " X_update = pd.concat(X_stream[max(int(i) - run_length, 0) : int(i)])\n", - " X_update = X_update[~X_update.index.duplicated(keep=\"first\")]\n", - " ind = X_update.index.get_level_values(0).unique()\n", - " X_update = reshape_inputs(X_update, num_timesteps)\n", - "\n", - " y_update = pd.concat(y_stream[max(int(i) - run_length, 0) : int(i)])\n", - " y_update.index = ind\n", - " y_update = y_update[~y_update.index.duplicated(keep=\"first\")].to_numpy()\n", - "\n", - " # Get updated source (validation) data for\n", - " # two-sample test (including data for retraining)\n", - " X_prev = np.concatenate((X_prev, X_update), axis=0)\n", - " tups = [tuple(row) for row in X_prev]\n", - " X_prev = np.unique(tups, axis=0)\n", - " np.random.shuffle(X_prev)\n", - "\n", - " print(\n", - " \"Retrain \",\n", - " model_name,\n", - " \" on: \",\n", - " max(int(i) - run_length, 0),\n", - " \"-\",\n", - " int(i),\n", - " )\n", - "\n", - " if model_name == \"rnn\":\n", - " # create train loader\n", - " update_dataset = get_data(X_update, y_update)\n", - " 
update_loader = torch.utils.data.DataLoader(\n", - " update_dataset,\n", - " batch_size=1,\n", - " shuffle=False,\n", - " )\n", - "\n", - " retrain_model_path = \"adaptive_window_retrain.model\"\n", - "\n", - " # train\n", - " opt.train(\n", - " update_loader,\n", - " update_loader,\n", - " batch_size=batch_size,\n", - " n_epochs=n_epochs,\n", - " n_features=input_dim,\n", - " timesteps=timesteps,\n", - " model_path=retrain_model_path,\n", - " )\n", - "\n", - " model.load_state_dict(torch.load(retrain_model_path))\n", - " opt.model = model\n", - " shift_detector.model_path = retrain_model_path\n", - "\n", - " elif model_name == \"gbt\":\n", - " X_retrain, y_retrain = None, None\n", - " model = model.fit(\n", - " X_retrain,\n", - " y_retrain,\n", - " xgb_model=model.get_booster(),\n", - " )\n", - "\n", - " else:\n", - " print(\"Invalid Model Name\")\n", - "\n", - " i += stride\n", - "\n", - " if X_ref is None:\n", - " X_prev = pd.concat(\n", - " X_stream[max(int(i) - run_length, 0) : int(i) + stat_window],\n", - " )\n", - " X_prev = X_prev[~X_prev.index.duplicated(keep=\"first\")]\n", - " X_prev = reshape_inputs(X_prev, num_timesteps)\n", - "\n", - " # Get next stream of test data\n", - " X_next = pd.concat(\n", - " X_stream[\n", - " max(int(i) + lookup_window, 0) : int(i) + stat_window + lookup_window\n", - " ],\n", - " )\n", - " X_next = X_next[~X_next.index.duplicated(keep=\"first\")]\n", - " next_ind = X_next.index.get_level_values(0).unique()\n", - " X_next = reshape_inputs(X_next, num_timesteps)\n", - "\n", - " y_next = pd.concat(\n", - " y_stream[\n", - " max(int(i) + lookup_window, 0) : int(i) + stat_window + lookup_window\n", - " ],\n", - " )\n", - " y_next.index = next_ind\n", - " y_next = y_next[~y_next.index.duplicated(keep=\"first\")].to_numpy()\n", - "\n", - " # Ensure next stream of test data is not empty\n", - " if X_next.shape[0] <= 2 or X_prev.shape[0] <= 2:\n", - " break\n", - "\n", - " # Check Performance\n", - " test_dataset = get_data(X_next, 
y_next)\n", - " test_loader = torch.utils.data.DataLoader(\n", - " test_dataset,\n", - " batch_size=1,\n", - " shuffle=False,\n", - " )\n", - " y_test_labels, y_pred_values, y_pred_labels = opt.evaluate(\n", - " test_loader,\n", - " batch_size=1,\n", - " n_features=input_dim,\n", - " timesteps=num_timesteps,\n", - " )\n", - "\n", - " y_pred_values = y_pred_values[y_test_labels != -1]\n", - " y_pred_labels = y_pred_labels[y_test_labels != -1]\n", - " y_test_labels = y_test_labels[y_test_labels != -1]\n", - "\n", - " pred_metrics = print_metrics_binary(\n", - " y_test_labels,\n", - " y_pred_values,\n", - " y_pred_labels,\n", - " verbose=verbose,\n", - " )\n", - " rolling_metrics.append(\n", - " pd.DataFrame(pred_metrics.values(), index=pred_metrics.keys()).T,\n", - " )\n", - "\n", - " # Run distribution shift check here\n", - " (p_val, dist, val_acc, te_acc) = shift_detector.detect_data_shift(\n", - " X_train,\n", - " X_prev[:1000, :],\n", - " X_next[:sample, :],\n", - " )\n", - "\n", - " # print(max(int(i)-run_length,0),\"-\",\n", - " # \"-\",int(i)+stat_window+lookup_window,\"\\tP-Value: \",p_val)\n", - "\n", - " dist_vals = np.concatenate((dist_vals, np.repeat(dist, 1)))\n", - " p_vals = np.concatenate((p_vals, np.repeat(p_val, 1)))\n", - "\n", - " if p_val >= threshold:\n", - " run_length += stride\n", - " i += stride\n", - " else:\n", - " run_length = stat_window\n", - " total_alarms += 1\n", - "\n", - " rolling_metrics = pd.concat(rolling_metrics).reset_index(drop=True)\n", - "\n", - " return dist_vals, p_vals, rolling_metrics, total_alarms\n", - "\n", - "\n", - "dist_test, pvals_test, performance_metrics, total_alarms = confidence_rolling_window(\n", - " X_tr_final,\n", - " x_test_stream,\n", - " y_test_stream,\n", - " shift_detector,\n", - " sample,\n", - " stat_window,\n", - " lookup_window,\n", - " stride,\n", - " num_timesteps,\n", - " threshold,\n", - " model_name=model_name,\n", - " model=model,\n", - " opt=opt,\n", - " X_ref=X_val_final,\n", - " 
retrain=retrain,\n", - ")" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "id": "9b667946-1a75-4be0-8ccf-df05396b0ed2", - "metadata": {}, - "outputs": [], - "source": [ - "total_alarms" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "id": "12cdf7ad-9b65-4370-b18f-f1980a8f2c25", - "metadata": {}, - "outputs": [], - "source": [ - "mean = np.mean(pvals_test[pvals_test < 0.05])\n", - "ci = st.t.interval(\n", - " 0.95,\n", - " len(pvals_test[pvals_test < 0.05]) - 1,\n", - " loc=np.mean(pvals_test[pvals_test < 0.05]),\n", - " scale=st.sem(pvals_test[pvals_test < 0.05]),\n", - ")\n", - "print(mean, ci)" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "id": "abf9d8fd-0dd6-4645-9a7f-1b7060109d3f", - "metadata": {}, - "outputs": [], - "source": [ - "end = performance_metrics.shape[0]\n", - "threshold = 0.05\n", - "measure_dates_test_adjust = [\n", - " (\n", - " datetime.datetime.strptime(date, \"%Y-%m-%d\")\n", - " + datetime.timedelta(days=lookup_window + stat_window)\n", - " ).strftime(\"%Y-%m-%d\")\n", - " for date in measure_dates_test\n", - "]\n", - "fig, (ax1, ax2, ax3, ax4, ax5, ax6) = plt.subplots(6, 1, figsize=(22, 12))\n", - "results = pd.DataFrame(\n", - " {\n", - " \"dates\": measure_dates_test_adjust[0:end],\n", - " \"pval\": pvals_test[0:end],\n", - " \"dist\": dist_test[0:end],\n", - " \"detection\": np.where(pvals_test[0:end] < threshold, 1, 0),\n", - " },\n", - ")\n", - "results = pd.concat([results, performance_metrics], axis=1)\n", - "results.to_pickle(\n", - " os.path.join(\n", - " PATH,\n", - " shift,\n", - " shift + \"_\" + dr_technique + \"_\" + md_test + \"_results.pkl\",\n", - " ),\n", - ")\n", - "start = 0\n", - "end = performance_metrics.shape[0] - 1\n", - "cmap = ListedColormap([\"lightgrey\", \"red\"])\n", - "ax1.plot(\n", - " results[\"dates\"],\n", - " results[\"pval\"],\n", - " \".-\",\n", - " color=\"red\",\n", - " linewidth=0.5,\n", - " markersize=2,\n", - ")\n", - 
"ax1.set_xlim(results[\"dates\"][start], results[\"dates\"][end])\n", - "ax1.axhline(y=threshold, color=\"dimgrey\", linestyle=\"--\")\n", - "ax1.set_ylabel(\"P-Values\", fontsize=16)\n", - "ax1.set_xticklabels([])\n", - "ax1.pcolorfast(\n", - " ax1.get_xlim(),\n", - " ax1.get_ylim(),\n", - " results[\"detection\"].values[np.newaxis],\n", - " cmap=cmap,\n", - " alpha=0.4,\n", - ")\n", - "\n", - "ax2.plot(\n", - " results[\"dates\"],\n", - " results[\"dist\"],\n", - " \".-\",\n", - " color=\"red\",\n", - " linewidth=0.5,\n", - " markersize=2,\n", - ")\n", - "ax2.set_xlim(results[\"dates\"][start], results[\"dates\"][end])\n", - "ax2.set_ylabel(\"Distance\", fontsize=16)\n", - "ax2.axhline(y=np.mean(results[\"dist\"]), color=\"dimgrey\", linestyle=\"--\")\n", - "ax2.set_xticklabels([])\n", - "ax2.pcolorfast(\n", - " ax2.get_xlim(),\n", - " ax2.get_ylim(),\n", - " results[\"detection\"].values[np.newaxis],\n", - " cmap=cmap,\n", - " alpha=0.4,\n", - ")\n", - "\n", - "ax3.plot(\n", - " results[\"dates\"],\n", - " results[\"auroc\"],\n", - " \".-\",\n", - " color=\"blue\",\n", - " linewidth=0.5,\n", - " markersize=2,\n", - ")\n", - "ax3.set_xlim(results[\"dates\"][start], results[\"dates\"][end])\n", - "ax3.set_ylabel(\"AUROC\", fontsize=16)\n", - "ax3.axhline(y=np.mean(results[\"auroc\"]), color=\"dimgrey\", linestyle=\"--\")\n", - "ax3.set_xticklabels([])\n", - "ax3.pcolorfast(\n", - " ax3.get_xlim(),\n", - " ax3.get_ylim(),\n", - " results[\"detection\"].values[np.newaxis],\n", - " cmap=cmap,\n", - " alpha=0.4,\n", - ")\n", - "\n", - "ax4.plot(\n", - " results[\"dates\"],\n", - " results[\"auprc\"],\n", - " \".-\",\n", - " color=\"blue\",\n", - " linewidth=0.5,\n", - " markersize=2,\n", - ")\n", - "ax4.set_xlim(results[\"dates\"][start], results[\"dates\"][end])\n", - "ax4.set_ylabel(\"AUPRC\", fontsize=16)\n", - "ax4.axhline(y=np.mean(results[\"auprc\"]), color=\"dimgrey\", linestyle=\"--\")\n", - "ax4.set_xticklabels([])\n", - "ax4.pcolorfast(\n", - " 
ax4.get_xlim(),\n", - " ax4.get_ylim(),\n", - " results[\"detection\"].values[np.newaxis],\n", - " cmap=cmap,\n", - " alpha=0.4,\n", - ")\n", - "\n", - "ax5.plot(\n", - " results[\"dates\"],\n", - " results[\"prec1\"],\n", - " \".-\",\n", - " color=\"blue\",\n", - " linewidth=0.5,\n", - " markersize=2,\n", - ")\n", - "ax5.set_xlim(results[\"dates\"][start], results[\"dates\"][end])\n", - "ax5.set_ylabel(\"PPV\", fontsize=16)\n", - "ax5.axhline(y=np.mean(results[\"prec1\"]), color=\"dimgrey\", linestyle=\"--\")\n", - "ax5.set_xticklabels([])\n", - "ax5.pcolorfast(\n", - " ax5.get_xlim(),\n", - " ax5.get_ylim(),\n", - " results[\"detection\"].values[np.newaxis],\n", - " cmap=cmap,\n", - " alpha=0.4,\n", - ")\n", - "\n", - "ax6.plot(\n", - " results[\"dates\"],\n", - " results[\"rec1\"],\n", - " \".-\",\n", - " color=\"blue\",\n", - " linewidth=0.5,\n", - " markersize=2,\n", - ")\n", - "ax6.set_xlim(results[\"dates\"][start], results[\"dates\"][end])\n", - "ax6.set_ylabel(\"Sensitivity\", fontsize=16)\n", - "ax6.set_xlabel(\"time (s)\", fontsize=16)\n", - "ax6.axhline(y=np.mean(results[\"rec1\"]), color=\"dimgrey\", linestyle=\"--\")\n", - "ax6.tick_params(axis=\"x\", labelrotation=45)\n", - "ax6.pcolorfast(\n", - " ax6.get_xlim(),\n", - " ax6.get_ylim(),\n", - " results[\"detection\"].values[np.newaxis],\n", - " cmap=cmap,\n", - " alpha=0.4,\n", - ")\n", - "\n", - "for index, label in enumerate(ax6.xaxis.get_ticklabels()):\n", - " if index % 28 != 0:\n", - " label.set_visible(False)\n", - "\n", - "plt.show()" - ] - } - ], - "metadata": { - "kernelspec": { - "display_name": "cyclops-KKtuQLwg-py3.9", - "language": "python", - "name": "cyclops-kktuqlwg-py3.9" - }, - "language_info": { - "codemirror_mode": { - "name": "ipython", - "version": 3 - }, - "file_extension": ".py", - "mimetype": "text/x-python", - "name": "python", - "nbconvert_exporter": "python", - "pygments_lexer": "ipython3", - "version": "3.9.7" - } - }, - "nbformat": 4, - "nbformat_minor": 5 -} diff --git 
a/nbs/monitor/retraining/cumulative.ipynb b/nbs/monitor/retraining/cumulative.ipynb deleted file mode 100644 index dea814895..000000000 --- a/nbs/monitor/retraining/cumulative.ipynb +++ /dev/null @@ -1,504 +0,0 @@ -{ - "cells": [ - { - "cell_type": "markdown", - "id": "2820e910-d3f5-47d0-a600-f688df3daada", - "metadata": {}, - "source": [ - "### Retraining, using all of the encounters to-date" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "id": "d995dcd3-fb7a-4398-b8b8-984c2ddc78be", - "metadata": {}, - "outputs": [], - "source": [ - "import datetime\n", - "import os\n", - "import random\n", - "from datetime import date\n", - "\n", - "import matplotlib.pyplot as plt\n", - "import numpy as np\n", - "import scipy.stats as st\n", - "from baseline_models.temporal.pytorch.optimizer import Optimizer\n", - "from baseline_models.temporal.pytorch.utils import get_device, load_ckp\n", - "from drift_detector.detector import Detector\n", - "from drift_detector.reductor import Reductor\n", - "from drift_detector.tester import TSTester\n", - "from drift_detector.utils import get_serving_data, get_temporal_model\n", - "from gemini.query import get_gemini_data\n", - "from gemini.utils import get_label, import_dataset_hospital, normalize, process, scale\n", - "from matplotlib.colors import ListedColormap\n", - "from retrainers.cumulative import CumulativeRetrainer\n", - "from torch import nn, optim" - ] - }, - { - "cell_type": "markdown", - "id": "95b0bc00-3c3e-4846-a89f-3d5ad2c15b8c", - "metadata": {}, - "source": [ - "## Get parameters" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "id": "17cff9cf-c682-4973-ab27-1161b70d516a", - "metadata": {}, - "outputs": [], - "source": [ - "PATH = \"/mnt/nfs/project/delirium/drift_exp/JULY-04-2022/\"\n", - "TIMESTEPS = 6\n", - "AGGREGATION_TYPE = \"time\"\n", - "HOSPITALS = [\"SMH\", \"MSH\", \"THPC\", \"THPM\", \"UHNTG\", \"UHNTW\", \"PMH\", \"SBK\"]\n", - "ACADEMIC = [\"MSH\", \"PMH\", \"SMH\", 
\"UHNTW\", \"UHNTG\", \"PMH\", \"SBK\"]\n", - "COMMUNITY = [\"THPC\", \"THPM\"]\n", - "OUTCOME = \"mortality\"\n", - "THRESHOLD = 0.05\n", - "NUM_TIMESTEPS = 6\n", - "STAT_WINDOW = 30\n", - "LOOKUP_WINDOW = 0\n", - "STRIDE = 1\n", - "\n", - "SHIFT = input(\"Select experiment: \") # hospital_type\n", - "MODEL_PATH = os.path.join(PATH, \"saved_models\", SHIFT + \"_lstm.pt\")\n", - "\n", - "if SHIFT == \"simulated_deployment\":\n", - " exp_params = {\n", - " \"source\": [datetime.date(2015, 1, 1), datetime.date(2019, 1, 1)],\n", - " \"target\": [datetime.date(2019, 1, 1), datetime.date(2020, 8, 1)],\n", - " \"shift_type\": \"source_target\",\n", - " }\n", - "\n", - "if SHIFT == \"covid\":\n", - " exp_params = {\n", - " \"source\": [datetime.date(2019, 1, 1), datetime.date(2020, 2, 1)],\n", - " \"target\": [datetime.date(2020, 3, 1), datetime.date(2020, 8, 1)],\n", - " \"shift_type\": \"time\",\n", - " }\n", - "\n", - "if SHIFT == \"seasonal_summer\":\n", - " exp_params = {\n", - " \"source\": [1, 2, 3, 4, 5, 10, 11, 12],\n", - " \"target\": [6, 7, 8, 9],\n", - " \"shift_type\": \"month\",\n", - " }\n", - "\n", - "if SHIFT == \"seasonal_winter\":\n", - " exp_params = {\n", - " \"source\": [3, 4, 5, 6, 7, 8, 9, 10],\n", - " \"target\": [11, 12, 1, 2],\n", - " \"shift_type\": \"month\",\n", - " }\n", - "\n", - "if SHIFT == \"hosp_type_academic\":\n", - " exp_params = {\n", - " \"source\": ACADEMIC,\n", - " \"target\": COMMUNITY,\n", - " \"shift_type\": \"hospital_type\",\n", - " }\n", - "\n", - "if SHIFT == \"hosp_type_community\":\n", - " exp_params = {\n", - " \"source\": COMMUNITY,\n", - " \"target\": ACADEMIC,\n", - " \"shift_type\": \"hospital_type\",\n", - " }" - ] - }, - { - "cell_type": "markdown", - "id": "a468f8d3-19f5-44ab-843f-917a61dc45de", - "metadata": {}, - "source": [ - "## Get data" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "id": "93113177-c6d6-44af-9f74-0c043c781416", - "metadata": {}, - "outputs": [], - "source": [ - 
"admin_data, x, y = get_gemini_data(PATH)" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "id": "2094d958-185d-42f3-b3d1-eea9b91341e0", - "metadata": {}, - "outputs": [], - "source": [ - "random.seed(1)\n", - "\n", - "(X_tr, y_tr), (X_val, y_val), (X_t, y_t), feats, admin_data = import_dataset_hospital(\n", - " admin_data,\n", - " x,\n", - " y,\n", - " SHIFT,\n", - " OUTCOME,\n", - " HOSPITALS,\n", - ")\n", - "\n", - "# Normalize data\n", - "X_tr_normalized = normalize(admin_data, X_tr, AGGREGATION_TYPE, TIMESTEPS)\n", - "X_val_normalized = normalize(admin_data, X_val, AGGREGATION_TYPE, TIMESTEPS)\n", - "X_t_normalized = normalize(admin_data, X_t, AGGREGATION_TYPE, TIMESTEPS)\n", - "\n", - "if AGGREGATION_TYPE != \"time\":\n", - " y_tr = get_label(admin_data, X_tr, OUTCOME)\n", - " y_val = get_label(admin_data, X_val, OUTCOME)\n", - " y_t = get_label(admin_data, X_t, OUTCOME)\n", - "\n", - "# Scale data\n", - "X_tr_scaled = scale(X_tr_normalized)\n", - "X_val_scaled = scale(X_val_normalized)\n", - "X_t_scaled = scale(X_t_normalized)\n", - "\n", - "# Process data\n", - "X_tr_final = process(X_tr_scaled, AGGREGATION_TYPE, TIMESTEPS)\n", - "X_val_final = process(X_val_scaled, AGGREGATION_TYPE, TIMESTEPS)\n", - "X_t_final = process(X_t_scaled, AGGREGATION_TYPE, TIMESTEPS)" - ] - }, - { - "cell_type": "markdown", - "id": "cb0a5be5-e6e5-48da-aa6a-94511b907a69", - "metadata": {}, - "source": [ - "## Create data streams" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "id": "950b8746-7262-489e-bd31-c4e381900525", - "metadata": {}, - "outputs": [], - "source": [ - "START_DATE = date(2019, 1, 1)\n", - "END_DATE = date(2020, 8, 1)\n", - "\n", - "print(\"Get target data streams...\")\n", - "data_streams = get_serving_data(\n", - " x,\n", - " y,\n", - " admin_data,\n", - " START_DATE,\n", - " END_DATE,\n", - " stride=1,\n", - " window=1,\n", - " encounter_id=\"encounter_id\",\n", - " admit_timestamp=\"admit_timestamp\",\n", - ")" - ] - }, - { 
- "cell_type": "markdown", - "id": "16c63c96-d326-4611-bc6b-de156c1f3f21", - "metadata": {}, - "source": [ - "## Get prediction model" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "id": "485aee51-3a9e-4e03-ad94-7e9f1b1251d6", - "metadata": {}, - "outputs": [], - "source": [ - "retrain = \"update\"\n", - "model_name = \"lstm\"\n", - "output_dim = 1\n", - "input_dim = 108\n", - "hidden_dim = 64\n", - "layer_dim = 2\n", - "dropout = 0.2\n", - "last_timestep_only = False\n", - "device = get_device()\n", - "\n", - "model_params = {\n", - " \"device\": device,\n", - " \"input_dim\": input_dim,\n", - " \"hidden_dim\": hidden_dim,\n", - " \"layer_dim\": layer_dim,\n", - " \"output_dim\": output_dim,\n", - " \"dropout_prob\": dropout,\n", - " \"last_timestep_only\": last_timestep_only,\n", - "}\n", - "\n", - "model = get_temporal_model(model_name, model_params).to(device)\n", - "model, optimizer, n_epochs = load_ckp(MODEL_PATH, model)\n", - "\n", - "# Load model and trainer\n", - "if model_name in [\"rnn\", \"gru\", \"lstm\"]:\n", - " model = get_temporal_model(\"lstm\", model_params).to(device)\n", - "\n", - " if retrain == \"update\":\n", - " checkpoint_fpath = os.path.join(PATH, \"saved_models\", SHIFT + \"_lstm.pt\")\n", - " model, opt, n_epochs = load_ckp(checkpoint_fpath, model)\n", - " n_epochs = 1\n", - " else:\n", - " n_epochs = 64\n", - " learning_rate = 2e-3\n", - " weight_decay = 1e-6\n", - " loss_fn = nn.BCEWithLogitsLoss(reduction=\"none\")\n", - " optimizer = optim.Adagrad(\n", - " model.parameters(),\n", - " lr=learning_rate,\n", - " weight_decay=weight_decay,\n", - " )\n", - " lr_scheduler = optim.lr_scheduler.StepLR(optimizer, step_size=128, gamma=0.5)\n", - " activation = nn.Sigmoid()\n", - " opt = Optimizer(\n", - " model=model,\n", - " loss_fn=loss_fn,\n", - " optimizer=optimizer,\n", - " activation=activation,\n", - " lr_scheduler=lr_scheduler,\n", - " )\n", - "# with open(model_path, \"rb\") as f:\n", - "else:\n", - " 
print(\"Unsupported model\")" - ] - }, - { - "cell_type": "markdown", - "id": "ea405f03-59e3-4e6a-a604-6f8b30846860", - "metadata": {}, - "source": [ - "## Get shift detector" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "id": "b5268d3d-a89b-409a-abd9-5e0a7c91ff08", - "metadata": {}, - "outputs": [], - "source": [ - "DR_TECHNIQUE = \"BBSDs_trained_LSTM\"\n", - "MD_TEST = \"mmd\"\n", - "SAMPLE = 1000\n", - "CONTEXT_TYPE = \"lstm\"\n", - "PROJ_TYPE = \"lstm\"\n", - "\n", - "print(\"Get Shift Reductor...\")\n", - "reductor = Reductor(\n", - " dr_method=DR_TECHNIQUE,\n", - " model_path=MODEL_PATH,\n", - " n_features=len(feats),\n", - " var_ret=0.8,\n", - ")\n", - "\n", - "print(\"Get Shift Tester...\")\n", - "tester = TSTester(tester_method=MD_TEST)\n", - "\n", - "print(\"Get Shift Detector...\")\n", - "detector = Detector(\n", - " reductor=reductor,\n", - " tester=tester,\n", - " p_val_threshold=0.05,\n", - ")\n", - "detector.fit(X_tr_final)" - ] - }, - { - "cell_type": "markdown", - "id": "907b2730-50ba-44fb-9df5-8e7fcaa468cd", - "metadata": {}, - "source": [ - "## Retrain" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "id": "9267fc0a-796e-42fe-923c-c140d9cb6d98", - "metadata": {}, - "outputs": [], - "source": [ - "retrainer = CumulativeRetrainer(\n", - " shift_detector=detector,\n", - " optimizer=optimizer,\n", - " model=model,\n", - " model_name=model_name,\n", - ")" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "id": "277c50c2-8c65-4184-82ca-2dd2d28e0717", - "metadata": {}, - "outputs": [], - "source": [ - "all_runs = []\n", - "for _i in range(0, 5):\n", - " random.seed(1)\n", - "\n", - " run_dict = retrainer.retrain(\n", - " data_streams=data_streams,\n", - " model_path=MODEL_PATH,\n", - " context_type=CONTEXT_TYPE,\n", - " proj_type=PROJ_TYPE,\n", - " )\n", - "\n", - " all_runs.append(run_dict)\n", - " pvals_test = run_dict[\"p_val\"]\n", - " mean = np.mean(pvals_test[pvals_test < 0.05])\n", - " ci = 
st.t.interval(\n", - " 0.95,\n", - " len(pvals_test[pvals_test < 0.05]) - 1,\n", - " loc=np.mean(pvals_test[pvals_test < 0.05]),\n", - " scale=st.sem(pvals_test[pvals_test < 0.05]),\n", - " )\n", - " total_alarms = pvals_test[pvals_test < 0.05].sum()\n", - " print(total_alarms, \" alarms with avg p-value of \", mean, ci)\n", - " np.save(\n", - " os.path.join(PATH, SHIFT, SHIFT + \"_cumulative_10epochs_retraining_update.npy\"),\n", - " all_runs,\n", - " )" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "id": "7a904bb2-edba-4363-a6bd-1774bb9ff8c6", - "metadata": {}, - "outputs": [], - "source": [ - "results = run_dict\n", - "\n", - "p_val_threshold = 0.05\n", - "sig_drift = np.array(results[\"shift_detected\"])[np.newaxis]\n", - "\n", - "fig, (ax1, ax2, ax3, ax4, ax5, ax6) = plt.subplots(6, 1, figsize=(18, 12))\n", - "cmap = ListedColormap([\"lightgrey\", \"red\"])\n", - "ax1.plot(\n", - " results[\"timestamps\"],\n", - " results[\"p_val\"],\n", - " \".-\",\n", - " color=\"red\",\n", - " linewidth=0.5,\n", - " markersize=2,\n", - ")\n", - "ax1.set_xlim(results[\"timestamps\"][0], results[\"timestamps\"][-1])\n", - "ax1.axhline(y=p_val_threshold, color=\"dimgrey\", linestyle=\"--\")\n", - "ax1.set_ylabel(\"P-Values\", fontsize=16)\n", - "ax1.set_xticklabels([])\n", - "ax1.pcolorfast(ax1.get_xlim(), ax1.get_ylim(), sig_drift, cmap=cmap, alpha=0.4)\n", - "\n", - "ax2.plot(\n", - " results[\"timestamps\"],\n", - " results[\"distance\"],\n", - " \".-\",\n", - " color=\"red\",\n", - " linewidth=0.5,\n", - " markersize=2,\n", - ")\n", - "ax2.set_xlim(results[\"timestamps\"][0], results[\"timestamps\"][-1])\n", - "ax2.set_ylabel(\"Distance\", fontsize=16)\n", - "ax2.axhline(y=np.mean(results[\"distance\"]), color=\"dimgrey\", linestyle=\"--\")\n", - "ax2.set_xticklabels([])\n", - "ax2.pcolorfast(ax2.get_xlim(), ax2.get_ylim(), sig_drift, cmap=cmap, alpha=0.4)\n", - "\n", - "ax3.plot(\n", - " results[\"timestamps\"],\n", - " results[\"auroc\"],\n", - " 
\".-\",\n", - " color=\"blue\",\n", - " linewidth=0.5,\n", - " markersize=2,\n", - ")\n", - "ax3.set_xlim(results[\"timestamps\"][0], results[\"timestamps\"][-1])\n", - "ax3.set_ylabel(\"AUROC\", fontsize=16)\n", - "ax3.axhline(y=np.mean(results[\"auroc\"]), color=\"dimgrey\", linestyle=\"--\")\n", - "ax3.set_xticklabels([])\n", - "ax3.pcolorfast(ax3.get_xlim(), ax3.get_ylim(), sig_drift, cmap=cmap, alpha=0.4)\n", - "\n", - "ax4.plot(\n", - " results[\"timestamps\"],\n", - " results[\"auprc\"],\n", - " \".-\",\n", - " color=\"blue\",\n", - " linewidth=0.5,\n", - " markersize=2,\n", - ")\n", - "ax4.set_xlim(results[\"timestamps\"][0], results[\"timestamps\"][-1])\n", - "ax4.set_ylabel(\"AUPRC\", fontsize=16)\n", - "ax4.axhline(y=np.mean(results[\"auprc\"]), color=\"dimgrey\", linestyle=\"--\")\n", - "ax4.set_xticklabels([])\n", - "ax4.pcolorfast(ax4.get_xlim(), ax4.get_ylim(), sig_drift, cmap=cmap, alpha=0.4)\n", - "\n", - "ax5.plot(\n", - " results[\"timestamps\"],\n", - " results[\"prec1\"],\n", - " \".-\",\n", - " color=\"blue\",\n", - " linewidth=0.5,\n", - " markersize=2,\n", - ")\n", - "ax5.set_xlim(results[\"timestamps\"][0], results[\"timestamps\"][-1])\n", - "ax5.set_ylabel(\"PPV\", fontsize=16)\n", - "ax5.axhline(y=np.mean(results[\"prec1\"]), color=\"dimgrey\", linestyle=\"--\")\n", - "ax5.set_xticklabels([])\n", - "ax5.pcolorfast(ax5.get_xlim(), ax5.get_ylim(), sig_drift, cmap=cmap, alpha=0.4)\n", - "\n", - "ax6.plot(\n", - " results[\"timestamps\"],\n", - " results[\"rec1\"],\n", - " \".-\",\n", - " color=\"blue\",\n", - " linewidth=0.5,\n", - " markersize=2,\n", - ")\n", - "ax6.set_xlim(results[\"timestamps\"][0], results[\"timestamps\"][-1])\n", - "ax6.set_ylabel(\"Sensitivity\", fontsize=16)\n", - "ax6.set_xlabel(\"time (s)\", fontsize=16)\n", - "ax6.axhline(y=np.mean(results[\"rec1\"]), color=\"dimgrey\", linestyle=\"--\")\n", - "ax6.tick_params(axis=\"x\", labelrotation=45)\n", - "ax6.pcolorfast(ax6.get_xlim(), ax6.get_ylim(), sig_drift, cmap=cmap, 
alpha=0.4)\n", - "\n", - "for index, label in enumerate(ax6.xaxis.get_ticklabels()):\n", - " if index % 28 != 0:\n", - " label.set_visible(False)\n", - "\n", - "plt.show()" - ] - } - ], - "metadata": { - "kernelspec": { - "display_name": "Python 3 (ipykernel)", - "language": "python", - "name": "python3" - }, - "language_info": { - "codemirror_mode": { - "name": "ipython", - "version": 3 - }, - "file_extension": ".py", - "mimetype": "text/x-python", - "name": "python", - "nbconvert_exporter": "python", - "pygments_lexer": "ipython3", - "version": "3.10.4" - }, - "vscode": { - "interpreter": { - "hash": "bd2cd438e1c6ddffa3035fc73b17ac5cc0e0ea8897eb8be17cc645c6abf0c8cc" - } - } - }, - "nbformat": 4, - "nbformat_minor": 5 -} diff --git a/nbs/monitor/retraining/local_global.ipynb b/nbs/monitor/retraining/local_global.ipynb deleted file mode 100644 index 762a940b1..000000000 --- a/nbs/monitor/retraining/local_global.ipynb +++ /dev/null @@ -1,240 +0,0 @@ -{ - "cells": [ - { - "cell_type": "markdown", - "id": "d313c923-4125-4817-a2cc-142c02dbd842", - "metadata": {}, - "source": [ - "### Compare performance of local vs global models" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "id": "2a4bb2e4-c369-4a09-b493-e637eedc3e92", - "metadata": {}, - "outputs": [], - "source": [ - "import datetime\n", - "import os\n", - "import random\n", - "from datetime import date\n", - "\n", - "from baseline_models.temporal.pytorch.utils import get_device, load_ckp\n", - "from drift_detector.utils import get_serving_data, get_temporal_model\n", - "from gemini.query import get_gemini_data\n", - "from gemini.utils import get_label, import_dataset_hospital, normalize, process, scale" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "id": "5bfe62aa-1159-448f-83ca-4153e460fc5a", - "metadata": {}, - "outputs": [], - "source": [ - "PATH = \"/mnt/nfs/project/delirium/drift_exp/JULY-04-2022/\"\n", - "TIMESTEPS = 6\n", - "AGGREGATION_TYPE = \"time\"\n", - "HOSPITALS = 
[\"SMH\", \"MSH\", \"THPC\", \"THPM\", \"UHNTG\", \"UHNTW\", \"PMH\", \"SBK\"]\n", - "OUTCOME = \"mortality\"\n", - "THRESHOLD = 0.05\n", - "NUM_TIMESTEPS = 6\n", - "STAT_WINDOW = 30\n", - "LOOKUP_WINDOW = 0\n", - "STRIDE = 1\n", - "\n", - "SHIFT = input(\"Select experiment: \") # hospital_type\n", - "MODEL_PATH = os.path.join(PATH, \"saved_models\", SHIFT + \"_lstm.pt\")\n", - "\n", - "if SHIFT == \"simulated_deployment\":\n", - " exp_params = {\n", - " \"source\": [datetime.date(2015, 1, 1), datetime.date(2019, 1, 1)],\n", - " \"target\": [datetime.date(2019, 1, 1), datetime.date(2020, 8, 1)],\n", - " \"shift_type\": \"source_target\",\n", - " }\n", - "\n", - "if SHIFT == \"covid\":\n", - " exp_params = {\n", - " \"source\": [datetime.date(2019, 1, 1), datetime.date(2020, 2, 1)],\n", - " \"target\": [datetime.date(2020, 3, 1), datetime.date(2020, 8, 1)],\n", - " \"shift_type\": \"time\",\n", - " }\n", - "\n", - "if SHIFT == \"seasonal_summer\":\n", - " exp_params = {\n", - " \"source\": [1, 2, 3, 4, 5, 10, 11, 12],\n", - " \"target\": [6, 7, 8, 9],\n", - " \"shift_type\": \"month\",\n", - " }\n", - "\n", - "if SHIFT == \"seasonal_winter\":\n", - " exp_params = {\n", - " \"source\": [3, 4, 5, 6, 7, 8, 9, 10],\n", - " \"target\": [11, 12, 1, 2],\n", - " \"shift_type\": \"month\",\n", - " }" - ] - }, - { - "cell_type": "markdown", - "id": "f3fb7e9a-b31f-408c-a4d0-dce81ca06caf", - "metadata": {}, - "source": [ - "## Get data" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "id": "2dd67da8-f0f7-4c2f-8089-c68b88df9155", - "metadata": {}, - "outputs": [], - "source": [ - "admin_data, x, y = get_gemini_data(PATH)" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "id": "8b29d43a-00c4-45b1-8d09-939c563061ad", - "metadata": {}, - "outputs": [], - "source": [ - "random.seed(1)\n", - "\n", - "(X_tr, y_tr), (X_val, y_val), (X_t, y_t), feats, admin_data = import_dataset_hospital(\n", - " admin_data,\n", - " x,\n", - " y,\n", - " SHIFT,\n", - " 
OUTCOME,\n", - " HOSPITALS,\n", - ")\n", - "\n", - "# Normalize data\n", - "X_tr_normalized = normalize(admin_data, X_tr, AGGREGATION_TYPE, TIMESTEPS)\n", - "X_val_normalized = normalize(admin_data, X_val, AGGREGATION_TYPE, TIMESTEPS)\n", - "X_t_normalized = normalize(admin_data, X_t, AGGREGATION_TYPE, TIMESTEPS)\n", - "\n", - "if AGGREGATION_TYPE != \"time\":\n", - " y_tr = get_label(admin_data, X_tr, OUTCOME)\n", - " y_val = get_label(admin_data, X_val, OUTCOME)\n", - " y_t = get_label(admin_data, X_t, OUTCOME)\n", - "\n", - "# Scale data\n", - "X_tr_scaled = scale(X_tr_normalized)\n", - "X_val_scaled = scale(X_val_normalized)\n", - "X_t_scaled = scale(X_t_normalized)\n", - "\n", - "# Process data\n", - "X_tr_final = process(X_tr_scaled, AGGREGATION_TYPE, TIMESTEPS)\n", - "X_val_final = process(X_val_scaled, AGGREGATION_TYPE, TIMESTEPS)\n", - "X_t_final = process(X_t_scaled, AGGREGATION_TYPE, TIMESTEPS)" - ] - }, - { - "cell_type": "markdown", - "id": "e4d1cd6d-205d-4718-b3ba-4876a9c0780d", - "metadata": {}, - "source": [ - "## Create data streams" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "id": "eee19975-ca8f-4cf8-a5c5-6460e0042ae8", - "metadata": {}, - "outputs": [], - "source": [ - "START_DATE = date(2019, 1, 1)\n", - "END_DATE = date(2020, 8, 1)" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "id": "71641296-59a6-4050-8d11-71e0322cbc52", - "metadata": {}, - "outputs": [], - "source": [ - "print(\"Get target data streams...\")\n", - "data_streams = get_serving_data(\n", - " x,\n", - " y,\n", - " admin_data,\n", - " START_DATE,\n", - " END_DATE,\n", - " stride=1,\n", - " window=1,\n", - " encounter_id=\"encounter_id\",\n", - " admit_timestamp=\"admit_timestamp\",\n", - ")" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "id": "d530efba-0fbc-434b-baca-56b1de87a57b", - "metadata": {}, - "outputs": [], - "source": [ - "output_dim = 1\n", - "input_dim = 108\n", - "hidden_dim = 64\n", - "layer_dim = 2\n", - 
"dropout = 0.2\n", - "last_timestep_only = False\n", - "device = get_device()\n", - "\n", - "model_params = {\n", - " \"device\": device,\n", - " \"input_dim\": input_dim,\n", - " \"hidden_dim\": hidden_dim,\n", - " \"layer_dim\": layer_dim,\n", - " \"output_dim\": output_dim,\n", - " \"dropout_prob\": dropout,\n", - " \"last_timestep_only\": last_timestep_only,\n", - "}\n", - "\n", - "model = get_temporal_model(\"lstm\", model_params).to(device)\n", - "model, optimizer, n_epochs = load_ckp(MODEL_PATH, model)" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "id": "17cd0c8e-0fb0-4e84-841e-d2df0b486a2c", - "metadata": {}, - "outputs": [], - "source": [] - } - ], - "metadata": { - "kernelspec": { - "display_name": "Python 3 (ipykernel)", - "language": "python", - "name": "python3" - }, - "language_info": { - "codemirror_mode": { - "name": "ipython", - "version": 3 - }, - "file_extension": ".py", - "mimetype": "text/x-python", - "name": "python", - "nbconvert_exporter": "python", - "pygments_lexer": "ipython3", - "version": "3.10.4" - }, - "vscode": { - "interpreter": { - "hash": "bd2cd438e1c6ddffa3035fc73b17ac5cc0e0ea8897eb8be17cc645c6abf0c8cc" - } - } - }, - "nbformat": 4, - "nbformat_minor": 5 -} diff --git a/nbs/monitor/retraining/most_recent.ipynb b/nbs/monitor/retraining/most_recent.ipynb deleted file mode 100644 index e7ee8f7f1..000000000 --- a/nbs/monitor/retraining/most_recent.ipynb +++ /dev/null @@ -1,505 +0,0 @@ -{ - "cells": [ - { - "cell_type": "markdown", - "id": "fe5c6ded-8970-44f4-b337-7e19909aaed9", - "metadata": {}, - "source": [ - "### Retraining using window of most recent encounters" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "id": "11504cf1-e260-41ab-8854-79df66acd092", - "metadata": {}, - "outputs": [], - "source": [ - "import datetime\n", - "import os\n", - "import random\n", - "from datetime import date\n", - "\n", - "import matplotlib.pyplot as plt\n", - "import numpy as np\n", - "import scipy.stats as 
st\n", - "from baseline_models.temporal.pytorch.optimizer import Optimizer\n", - "from baseline_models.temporal.pytorch.utils import get_device, load_ckp\n", - "from drift_detector.detector import Detector\n", - "from drift_detector.reductor import Reductor\n", - "from drift_detector.tester import TSTester\n", - "from drift_detector.utils import get_serving_data, get_temporal_model\n", - "from gemini.query import get_gemini_data\n", - "from gemini.utils import get_label, import_dataset_hospital, normalize, process, scale\n", - "from matplotlib.colors import ListedColormap\n", - "from retrainers.mostrecent import MostRecentRetrainer\n", - "from torch import nn, optim" - ] - }, - { - "cell_type": "markdown", - "id": "52a87aba-7eb0-436f-8304-56f8b8feab82", - "metadata": {}, - "source": [ - "## Get parameters" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "id": "17cff9cf-c682-4973-ab27-1161b70d516a", - "metadata": {}, - "outputs": [], - "source": [ - "PATH = \"/mnt/nfs/project/delirium/drift_exp/JULY-04-2022/\"\n", - "TIMESTEPS = 6\n", - "AGGREGATION_TYPE = \"time\"\n", - "HOSPITALS = [\"SMH\", \"MSH\", \"THPC\", \"THPM\", \"UHNTG\", \"UHNTW\", \"PMH\", \"SBK\"]\n", - "ACADEMIC = [\"MSH\", \"PMH\", \"SMH\", \"UHNTW\", \"UHNTG\", \"PMH\", \"SBK\"]\n", - "COMMUNITY = [\"THPC\", \"THPM\"]\n", - "OUTCOME = \"mortality\"\n", - "THRESHOLD = 0.05\n", - "NUM_TIMESTEPS = 6\n", - "STAT_WINDOW = 30\n", - "LOOKUP_WINDOW = 0\n", - "STRIDE = 1\n", - "\n", - "SHIFT = input(\"Select experiment: \") # hospital_type\n", - "MODEL_PATH = os.path.join(PATH, \"saved_models\", SHIFT + \"_lstm.pt\")\n", - "\n", - "if SHIFT == \"simulated_deployment\":\n", - " exp_params = {\n", - " \"source\": [datetime.date(2015, 1, 1), datetime.date(2019, 1, 1)],\n", - " \"target\": [datetime.date(2019, 1, 1), datetime.date(2020, 8, 1)],\n", - " \"shift_type\": \"source_target\",\n", - " }\n", - "\n", - "if SHIFT == \"covid\":\n", - " exp_params = {\n", - " \"source\": 
[datetime.date(2019, 1, 1), datetime.date(2020, 2, 1)],\n", - " \"target\": [datetime.date(2020, 3, 1), datetime.date(2020, 8, 1)],\n", - " \"shift_type\": \"time\",\n", - " }\n", - "\n", - "if SHIFT == \"seasonal_summer\":\n", - " exp_params = {\n", - " \"source\": [1, 2, 3, 4, 5, 10, 11, 12],\n", - " \"target\": [6, 7, 8, 9],\n", - " \"shift_type\": \"month\",\n", - " }\n", - "\n", - "if SHIFT == \"seasonal_winter\":\n", - " exp_params = {\n", - " \"source\": [3, 4, 5, 6, 7, 8, 9, 10],\n", - " \"target\": [11, 12, 1, 2],\n", - " \"shift_type\": \"month\",\n", - " }\n", - "\n", - "if SHIFT == \"hosp_type_academic\":\n", - " exp_params = {\n", - " \"source\": ACADEMIC,\n", - " \"target\": COMMUNITY,\n", - " \"shift_type\": \"hospital_type\",\n", - " }\n", - "\n", - "if SHIFT == \"hosp_type_community\":\n", - " exp_params = {\n", - " \"source\": COMMUNITY,\n", - " \"target\": ACADEMIC,\n", - " \"shift_type\": \"hospital_type\",\n", - " }" - ] - }, - { - "cell_type": "markdown", - "id": "92705618-9594-4c4d-9b8f-80c1ea8934d2", - "metadata": {}, - "source": [ - "## Get data" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "id": "93113177-c6d6-44af-9f74-0c043c781416", - "metadata": {}, - "outputs": [], - "source": [ - "admin_data, x, y = get_gemini_data(PATH)" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "id": "07d91c48-3b72-4510-bfbe-82f70193090c", - "metadata": {}, - "outputs": [], - "source": [ - "random.seed(1)\n", - "\n", - "(X_tr, y_tr), (X_val, y_val), (X_t, y_t), feats, admin_data = import_dataset_hospital(\n", - " admin_data,\n", - " x,\n", - " y,\n", - " SHIFT,\n", - " OUTCOME,\n", - " HOSPITALS,\n", - ")\n", - "\n", - "# Normalize data\n", - "X_tr_normalized = normalize(admin_data, X_tr, AGGREGATION_TYPE, TIMESTEPS)\n", - "X_val_normalized = normalize(admin_data, X_val, AGGREGATION_TYPE, TIMESTEPS)\n", - "X_t_normalized = normalize(admin_data, X_t, AGGREGATION_TYPE, TIMESTEPS)\n", - "\n", - "if AGGREGATION_TYPE != 
\"time\":\n", - " y_tr = get_label(admin_data, X_tr, OUTCOME)\n", - " y_val = get_label(admin_data, X_val, OUTCOME)\n", - " y_t = get_label(admin_data, X_t, OUTCOME)\n", - "\n", - "# Scale data\n", - "X_tr_scaled = scale(X_tr_normalized)\n", - "X_val_scaled = scale(X_val_normalized)\n", - "X_t_scaled = scale(X_t_normalized)\n", - "\n", - "# Process data\n", - "X_tr_final = process(X_tr_scaled, AGGREGATION_TYPE, TIMESTEPS)\n", - "X_val_final = process(X_val_scaled, AGGREGATION_TYPE, TIMESTEPS)\n", - "X_t_final = process(X_t_scaled, AGGREGATION_TYPE, TIMESTEPS)" - ] - }, - { - "cell_type": "markdown", - "id": "cb0a5be5-e6e5-48da-aa6a-94511b907a69", - "metadata": {}, - "source": [ - "## Create data streams" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "id": "950b8746-7262-489e-bd31-c4e381900525", - "metadata": {}, - "outputs": [], - "source": [ - "START_DATE = date(2019, 1, 1)\n", - "END_DATE = date(2020, 8, 1)\n", - "\n", - "print(\"Get target data streams...\")\n", - "data_streams = get_serving_data(\n", - " x,\n", - " y,\n", - " admin_data,\n", - " START_DATE,\n", - " END_DATE,\n", - " stride=1,\n", - " window=1,\n", - " encounter_id=\"encounter_id\",\n", - " admit_timestamp=\"admit_timestamp\",\n", - ")" - ] - }, - { - "cell_type": "markdown", - "id": "8fa8da8d-d6ca-4fac-bf81-48075c02cb9b", - "metadata": {}, - "source": [ - "## Get prediction model" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "id": "60d3fa3b-2fd4-41e6-9177-866f320d418a", - "metadata": {}, - "outputs": [], - "source": [ - "retrain = \"update\"\n", - "model_name = \"lstm\"\n", - "output_dim = 1\n", - "input_dim = 108\n", - "hidden_dim = 64\n", - "layer_dim = 2\n", - "dropout = 0.2\n", - "last_timestep_only = False\n", - "device = get_device()\n", - "\n", - "model_params = {\n", - " \"device\": device,\n", - " \"input_dim\": input_dim,\n", - " \"hidden_dim\": hidden_dim,\n", - " \"layer_dim\": layer_dim,\n", - " \"output_dim\": output_dim,\n", - " 
\"dropout_prob\": dropout,\n", - " \"last_timestep_only\": last_timestep_only,\n", - "}\n", - "\n", - "model = get_temporal_model(model_name, model_params).to(device)\n", - "model, optimizer, n_epochs = load_ckp(MODEL_PATH, model)\n", - "\n", - "# Load model and trainer\n", - "if model_name in [\"rnn\", \"gru\", \"lstm\"]:\n", - " model = get_temporal_model(\"lstm\", model_params).to(device)\n", - "\n", - " if retrain == \"update\":\n", - " checkpoint_fpath = os.path.join(PATH, \"saved_models\", SHIFT + \"_lstm.pt\")\n", - " model, opt, n_epochs = load_ckp(checkpoint_fpath, model)\n", - " n_epochs = 1\n", - " else:\n", - " n_epochs = 64\n", - " learning_rate = 2e-3\n", - " weight_decay = 1e-6\n", - " loss_fn = nn.BCEWithLogitsLoss(reduction=\"none\")\n", - " optimizer = optim.Adagrad(\n", - " model.parameters(),\n", - " lr=learning_rate,\n", - " weight_decay=weight_decay,\n", - " )\n", - " lr_scheduler = optim.lr_scheduler.StepLR(optimizer, step_size=128, gamma=0.5)\n", - " activation = nn.Sigmoid()\n", - " opt = Optimizer(\n", - " model=model,\n", - " loss_fn=loss_fn,\n", - " optimizer=optimizer,\n", - " activation=activation,\n", - " lr_scheduler=lr_scheduler,\n", - " )\n", - "# with open(model_path, \"rb\") as f:\n", - "else:\n", - " print(\"Unsupported model\")" - ] - }, - { - "cell_type": "markdown", - "id": "2630e452-160d-415c-bb39-a2c4e76b2ed6", - "metadata": {}, - "source": [ - "## Get shift detector" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "id": "aa5f1a4e-2925-4049-a10b-3ab4d14a8b16", - "metadata": {}, - "outputs": [], - "source": [ - "DR_TECHNIQUE = \"BBSDs_trained_LSTM\"\n", - "MD_TEST = \"mmd\"\n", - "SAMPLE = 1000\n", - "CONTEXT_TYPE = \"lstm\"\n", - "PROJ_TYPE = \"lstm\"\n", - "\n", - "print(\"Get Shift Reductor...\")\n", - "reductor = Reductor(\n", - " dr_method=DR_TECHNIQUE,\n", - " model_path=MODEL_PATH,\n", - " n_features=len(feats),\n", - " var_ret=0.8,\n", - ")\n", - "\n", - "print(\"Get Shift Tester...\")\n", - "tester 
= TSTester(tester_method=MD_TEST)\n", - "\n", - "print(\"Get Shift Detector...\")\n", - "detector = Detector(\n", - " reductor=reductor,\n", - " tester=tester,\n", - " p_val_threshold=0.05,\n", - ")\n", - "detector.fit(X_tr_final)" - ] - }, - { - "cell_type": "markdown", - "id": "9d6b18be-663a-42fc-a4c6-056e2be2fe7a", - "metadata": {}, - "source": [ - "## Retrain" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "id": "c2cde887-4d11-46d9-bcd0-5f15e4dc5494", - "metadata": {}, - "outputs": [], - "source": [ - "retrainer = MostRecentRetrainer(\n", - " shift_detector=detector,\n", - " optimizer=optimizer,\n", - " model=model,\n", - " model_name=model_name,\n", - " verbose=True,\n", - ")" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "id": "40f93059-4482-4d79-a94d-5bb1a12e2358", - "metadata": {}, - "outputs": [], - "source": [ - "all_runs = []\n", - "for _i in range(0, 5):\n", - " random.seed(1)\n", - "\n", - " run_dict = retrainer.retrain(\n", - " data_streams=data_streams,\n", - " model_path=MODEL_PATH,\n", - " context_type=CONTEXT_TYPE,\n", - " proj_type=PROJ_TYPE,\n", - " )\n", - "\n", - " all_runs.append(run_dict)\n", - " pvals_test = run_dict[\"pval\"]\n", - "\n", - " mean = np.mean(pvals_test[pvals_test < 0.05])\n", - " ci = st.t.interval(\n", - " 0.95,\n", - " len(pvals_test[pvals_test < 0.05]) - 1,\n", - " loc=np.mean(pvals_test[pvals_test < 0.05]),\n", - " scale=st.sem(pvals_test[pvals_test < 0.05]),\n", - " )\n", - " print(sum(pvals_test[pvals_test < 0.05]), \" alarms with avg p-value of \", mean, ci)\n", - " np.save(\n", - " os.path.join(PATH, SHIFT, SHIFT + \"_cumulative_10epochs_retraining_update.npy\"),\n", - " all_runs,\n", - " )" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "id": "abf9d8fd-0dd6-4645-9a7f-1b7060109d3f", - "metadata": {}, - "outputs": [], - "source": [ - "results = run_dict\n", - "\n", - "p_val_threshold = 0.05\n", - "sig_drift = np.array(results[\"shift_detected\"])[np.newaxis]\n", - 
"\n", - "fig, (ax1, ax2, ax3, ax4, ax5, ax6) = plt.subplots(6, 1, figsize=(18, 12))\n", - "cmap = ListedColormap([\"lightgrey\", \"red\"])\n", - "ax1.plot(\n", - " results[\"timestamps\"],\n", - " results[\"p_val\"],\n", - " \".-\",\n", - " color=\"red\",\n", - " linewidth=0.5,\n", - " markersize=2,\n", - ")\n", - "ax1.set_xlim(results[\"timestamps\"][0], results[\"timestamps\"][-1])\n", - "ax1.axhline(y=p_val_threshold, color=\"dimgrey\", linestyle=\"--\")\n", - "ax1.set_ylabel(\"P-Values\", fontsize=16)\n", - "ax1.set_xticklabels([])\n", - "ax1.pcolorfast(ax1.get_xlim(), ax1.get_ylim(), sig_drift, cmap=cmap, alpha=0.4)\n", - "\n", - "ax2.plot(\n", - " results[\"timestamps\"],\n", - " results[\"distance\"],\n", - " \".-\",\n", - " color=\"red\",\n", - " linewidth=0.5,\n", - " markersize=2,\n", - ")\n", - "ax2.set_xlim(results[\"timestamps\"][0], results[\"timestamps\"][-1])\n", - "ax2.set_ylabel(\"Distance\", fontsize=16)\n", - "ax2.axhline(y=np.mean(results[\"distance\"]), color=\"dimgrey\", linestyle=\"--\")\n", - "ax2.set_xticklabels([])\n", - "ax2.pcolorfast(ax2.get_xlim(), ax2.get_ylim(), sig_drift, cmap=cmap, alpha=0.4)\n", - "\n", - "ax3.plot(\n", - " results[\"timestamps\"],\n", - " results[\"auroc\"],\n", - " \".-\",\n", - " color=\"blue\",\n", - " linewidth=0.5,\n", - " markersize=2,\n", - ")\n", - "ax3.set_xlim(results[\"timestamps\"][0], results[\"timestamps\"][-1])\n", - "ax3.set_ylabel(\"AUROC\", fontsize=16)\n", - "ax3.axhline(y=np.mean(results[\"auroc\"]), color=\"dimgrey\", linestyle=\"--\")\n", - "ax3.set_xticklabels([])\n", - "ax3.pcolorfast(ax3.get_xlim(), ax3.get_ylim(), sig_drift, cmap=cmap, alpha=0.4)\n", - "\n", - "ax4.plot(\n", - " results[\"timestamps\"],\n", - " results[\"auprc\"],\n", - " \".-\",\n", - " color=\"blue\",\n", - " linewidth=0.5,\n", - " markersize=2,\n", - ")\n", - "ax4.set_xlim(results[\"timestamps\"][0], results[\"timestamps\"][-1])\n", - "ax4.set_ylabel(\"AUPRC\", fontsize=16)\n", - 
"ax4.axhline(y=np.mean(results[\"auprc\"]), color=\"dimgrey\", linestyle=\"--\")\n", - "ax4.set_xticklabels([])\n", - "ax4.pcolorfast(ax4.get_xlim(), ax4.get_ylim(), sig_drift, cmap=cmap, alpha=0.4)\n", - "\n", - "ax5.plot(\n", - " results[\"timestamps\"],\n", - " results[\"prec1\"],\n", - " \".-\",\n", - " color=\"blue\",\n", - " linewidth=0.5,\n", - " markersize=2,\n", - ")\n", - "ax5.set_xlim(results[\"timestamps\"][0], results[\"timestamps\"][-1])\n", - "ax5.set_ylabel(\"PPV\", fontsize=16)\n", - "ax5.axhline(y=np.mean(results[\"prec1\"]), color=\"dimgrey\", linestyle=\"--\")\n", - "ax5.set_xticklabels([])\n", - "ax5.pcolorfast(ax5.get_xlim(), ax5.get_ylim(), sig_drift, cmap=cmap, alpha=0.4)\n", - "\n", - "ax6.plot(\n", - " results[\"timestamps\"],\n", - " results[\"rec1\"],\n", - " \".-\",\n", - " color=\"blue\",\n", - " linewidth=0.5,\n", - " markersize=2,\n", - ")\n", - "ax6.set_xlim(results[\"timestamps\"][0], results[\"timestamps\"][-1])\n", - "ax6.set_ylabel(\"Sensitivity\", fontsize=16)\n", - "ax6.set_xlabel(\"time (s)\", fontsize=16)\n", - "ax6.axhline(y=np.mean(results[\"rec1\"]), color=\"dimgrey\", linestyle=\"--\")\n", - "ax6.tick_params(axis=\"x\", labelrotation=45)\n", - "ax6.pcolorfast(ax6.get_xlim(), ax6.get_ylim(), sig_drift, cmap=cmap, alpha=0.4)\n", - "\n", - "for index, label in enumerate(ax6.xaxis.get_ticklabels()):\n", - " if index % 28 != 0:\n", - " label.set_visible(False)\n", - "\n", - "plt.show()" - ] - } - ], - "metadata": { - "kernelspec": { - "display_name": "Python 3 (ipykernel)", - "language": "python", - "name": "python3" - }, - "language_info": { - "codemirror_mode": { - "name": "ipython", - "version": 3 - }, - "file_extension": ".py", - "mimetype": "text/x-python", - "name": "python", - "nbconvert_exporter": "python", - "pygments_lexer": "ipython3", - "version": "3.10.4" - }, - "vscode": { - "interpreter": { - "hash": "bd2cd438e1c6ddffa3035fc73b17ac5cc0e0ea8897eb8be17cc645c6abf0c8cc" - } - } - }, - "nbformat": 4, - 
"nbformat_minor": 5 -} diff --git a/nbs/monitor/rolling_window/hosp_type.ipynb b/nbs/monitor/rolling_window/hosp_type.ipynb deleted file mode 100644 index 227582335..000000000 --- a/nbs/monitor/rolling_window/hosp_type.ipynb +++ /dev/null @@ -1,412 +0,0 @@ -{ - "cells": [ - { - "cell_type": "code", - "execution_count": null, - "id": "9a0204ce-4c3e-44d4-ba4e-87700c720acf", - "metadata": {}, - "outputs": [], - "source": [ - "import datetime\n", - "import os\n", - "import random\n", - "from datetime import date\n", - "\n", - "import matplotlib.pyplot as plt\n", - "import numpy as np\n", - "from baseline_models.temporal.pytorch.utils import get_device, load_ckp\n", - "from drift_detector.detector import Detector\n", - "from drift_detector.reductor import Reductor\n", - "from drift_detector.rolling_window import RollingWindow\n", - "from drift_detector.tester import TSTester\n", - "from drift_detector.utils import get_serving_data, get_temporal_model\n", - "from gemini.query import get_gemini_data\n", - "from gemini.utils import get_label, import_dataset_hospital, normalize, process, scale\n", - "from matplotlib.colors import ListedColormap" - ] - }, - { - "cell_type": "markdown", - "id": "07278f40-75cf-4008-81c4-72741c6a0c39", - "metadata": {}, - "source": [ - "## Config parameters" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "id": "9dc229b2-82b7-45f6-808a-f03b438f09ba", - "metadata": {}, - "outputs": [], - "source": [ - "PATH = \"/mnt/nfs/project/delirium/drift_exp/JULY-04-2022/\"\n", - "TIMESTEPS = 6\n", - "AGGREGATION_TYPE = \"time\"\n", - "HOSPITALS = [\"SMH\", \"MSH\", \"THPC\", \"THPM\", \"UHNTG\", \"UHNTW\", \"PMH\", \"SBK\"]\n", - "ACADEMIC = [\"MSH\", \"PMH\", \"SMH\", \"UHNTW\", \"UHNTG\", \"PMH\", \"SBK\"]\n", - "COMMUNITY = [\"THPC\", \"THPM\"]\n", - "OUTCOME = \"mortality\"\n", - "THRESHOLD = 0.05\n", - "NUM_TIMESTEPS = 6\n", - "STAT_WINDOW = 30\n", - "LOOKUP_WINDOW = 0\n", - "STRIDE = 1\n", - "\n", - "SHIFT = input(\"Select 
experiment: \") # hospital_type\n", - "MODEL_PATH = os.path.join(PATH, \"saved_models\", SHIFT + \"_lstm.pt\")\n", - "\n", - "if SHIFT == \"hosp_type_academic\":\n", - " exp_params = {\n", - " \"source\": ACADEMIC,\n", - " \"target\": COMMUNITY,\n", - " \"shift_type\": \"hospital_type\",\n", - " }\n", - "\n", - "if SHIFT == \"hosp_type_community\":\n", - " exp_params = {\n", - " \"source\": COMMUNITY,\n", - " \"target\": ACADEMIC,\n", - " \"shift_type\": \"hospital_type\",\n", - " }" - ] - }, - { - "cell_type": "markdown", - "id": "237f08a9-4d4b-4696-ba7e-eee702b0ded6", - "metadata": {}, - "source": [ - "## Get data" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "id": "94627389-7636-4555-a27b-457d3980fad5", - "metadata": {}, - "outputs": [], - "source": [ - "admin_data, x, y = get_gemini_data(PATH)" - ] - }, - { - "cell_type": "markdown", - "id": "2db722df-32bf-4533-937b-de9f6c82a2fa", - "metadata": {}, - "source": [ - "## Get prediction model" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "id": "15a1d71e-19f8-4d70-bd65-c4c4f7e3124e", - "metadata": {}, - "outputs": [], - "source": [ - "output_dim = 1\n", - "batch_size = 64\n", - "input_dim = 108\n", - "timesteps = 6\n", - "hidden_dim = 64\n", - "layer_dim = 2\n", - "dropout = 0.2\n", - "last_timestep_only = False\n", - "\n", - "device = get_device()\n", - "\n", - "model_params = {\n", - " \"device\": device,\n", - " \"input_dim\": input_dim,\n", - " \"hidden_dim\": hidden_dim,\n", - " \"layer_dim\": layer_dim,\n", - " \"output_dim\": output_dim,\n", - " \"dropout_prob\": dropout,\n", - " \"last_timestep_only\": last_timestep_only,\n", - "}\n", - "\n", - "model = get_temporal_model(\"lstm\", model_params).to(device)\n", - "model, optimizer, n_epochs = load_ckp(MODEL_PATH, model)" - ] - }, - { - "cell_type": "markdown", - "id": "6ee33c06-61f3-4e6e-bb8c-556a8b73e967", - "metadata": {}, - "source": [ - "## Rolling window" - ] - }, - { - "cell_type": "code", - "execution_count": 
null, - "id": "2254a5f9-76cd-4f1a-bd76-b9127ed5bdfe", - "metadata": {}, - "outputs": [], - "source": [ - "DR_TECHNIQUE = \"BBSDs_trained_LSTM\"\n", - "MD_TEST = \"mmd\"\n", - "SAMPLE = 1000\n", - "CONTEXT_TYPE = \"lstm\"\n", - "PROJ_TYPE = \"lstm\"\n", - "START_DATE = date(2019, 1, 1)\n", - "END_DATE = date(2020, 8, 1)" - ] - }, - { - "cell_type": "markdown", - "id": "7734cf1f-fa8d-49e1-aff7-7effe912bd8e", - "metadata": {}, - "source": [ - "## Hospital type experiment over time" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "id": "5c83d394-e0ac-4668-b524-a7fa7816b452", - "metadata": { - "tags": [] - }, - "outputs": [], - "source": [ - "# Set constant reference distribution\n", - "random.seed(1)\n", - "print(\"Query data %s ...\" % SHIFT)\n", - "\n", - "(X_tr, y_tr), (X_val, y_val), (X_t, y_t), feats, admin_data = import_dataset_hospital(\n", - " admin_data,\n", - " x,\n", - " y,\n", - " SHIFT,\n", - " OUTCOME,\n", - " HOSPITALS,\n", - ")\n", - "\n", - "# Normalize data\n", - "X_tr_normalized = normalize(admin_data, X_tr, AGGREGATION_TYPE, TIMESTEPS)\n", - "X_val_normalized = normalize(admin_data, X_val, AGGREGATION_TYPE, TIMESTEPS)\n", - "X_t_normalized = normalize(admin_data, X_t, AGGREGATION_TYPE, TIMESTEPS)\n", - "\n", - "# Get labels\n", - "y_tr = get_label(admin_data, X_tr, OUTCOME)\n", - "y_val = get_label(admin_data, X_val, OUTCOME)\n", - "y_t = get_label(admin_data, X_t, OUTCOME)\n", - "\n", - "# Scale data\n", - "X_tr_scaled = scale(X_tr_normalized)\n", - "X_val_scaled = scale(X_val_normalized)\n", - "X_t_scaled = scale(X_t_normalized)\n", - "\n", - "# Process data\n", - "X_tr_final = process(X_tr_scaled, AGGREGATION_TYPE, TIMESTEPS)\n", - "X_val_final = process(X_val_scaled, AGGREGATION_TYPE, TIMESTEPS)\n", - "X_t_final = process(X_t_scaled, AGGREGATION_TYPE, TIMESTEPS)\n", - "\n", - "train_ids = list(X_tr_normalized.index.get_level_values(0).unique())\n", - "val_ids = list(X_val_normalized.index.get_level_values(0).unique())\n", - 
"exclude_ids = train_ids + val_ids\n", - "\n", - "print(\"Get target data streams...\")\n", - "data_streams = get_serving_data(\n", - " x,\n", - " y,\n", - " admin_data,\n", - " START_DATE,\n", - " END_DATE,\n", - " stride=1,\n", - " window=1,\n", - " ids_to_exclude=exclude_ids,\n", - " encounter_id=\"encounter_id\",\n", - " admit_timestamp=\"admit_timestamp\",\n", - ")\n", - "\n", - "print(\"Get Shift Reductor...\")\n", - "reductor = Reductor(\n", - " dr_method=DR_TECHNIQUE,\n", - " model_path=MODEL_PATH,\n", - " n_features=len(feats),\n", - " var_ret=0.8,\n", - ")\n", - "\n", - "print(\"Get Shift Tester...\")\n", - "tester = TSTester(\n", - " tester_method=MD_TEST,\n", - ")\n", - "\n", - "print(\"Get Shift Detector...\")\n", - "detector = Detector(\n", - " reductor=reductor,\n", - " tester=tester,\n", - " p_val_threshold=0.05,\n", - ")\n", - "\n", - "detector.fit(X_val_final)\n", - "\n", - "print(\"Get Rolling Window...\")\n", - "\n", - "rolling_window = RollingWindow(shift_detector=detector, optimizer=optimizer)\n", - "\n", - "drift_metrics = rolling_window.drift(\n", - " data_streams,\n", - " SAMPLE,\n", - " STAT_WINDOW,\n", - " LOOKUP_WINDOW,\n", - " STRIDE,\n", - ")\n", - "\n", - "performance_metrics = rolling_window.performance(\n", - " data_streams,\n", - " STAT_WINDOW,\n", - " LOOKUP_WINDOW,\n", - " STRIDE,\n", - ")\n", - "\n", - "results = {\n", - " \"timestamps\": [\n", - " (\n", - " datetime.datetime.strptime(date, \"%Y-%m-%d\")\n", - " + datetime.timedelta(days=LOOKUP_WINDOW + STAT_WINDOW)\n", - " ).strftime(\"%Y-%m-%d\")\n", - " for date in data_streams[\"timestamps\"]\n", - " ][:-STAT_WINDOW],\n", - "}\n", - "results.update(drift_metrics)\n", - "results.update(performance_metrics)\n", - "results.to_pickle(\n", - " os.path.join(PATH, SHIFT + \"_\" + DR_TECHNIQUE + \"_\" + MD_TEST + \"_results.pkl\"),\n", - ")" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "id": "89f8afbf-0f7f-4806-aed8-22fdc6d79315", - "metadata": {}, - "outputs": 
[], - "source": [ - "threshold = 0.05\n", - "sig_drift = np.array(results[\"shift_detected\"])[np.newaxis]\n", - "\n", - "fig, (ax1, ax2, ax3, ax4, ax5, ax6) = plt.subplots(6, 1, figsize=(18, 12))\n", - "cmap = ListedColormap([\"lightgrey\", \"red\"])\n", - "ax1.plot(\n", - " results[\"timestamps\"],\n", - " results[\"p_val\"],\n", - " \".-\",\n", - " color=\"red\",\n", - " linewidth=0.5,\n", - " markersize=2,\n", - ")\n", - "ax1.set_xlim(results[\"timestamps\"][0], results[\"timestamps\"][-1])\n", - "ax1.axhline(y=threshold, color=\"dimgrey\", linestyle=\"--\")\n", - "ax1.set_ylabel(\"P-Values\", fontsize=16)\n", - "ax1.set_xticklabels([])\n", - "ax1.pcolorfast(ax1.get_xlim(), ax1.get_ylim(), sig_drift, cmap=cmap, alpha=0.4)\n", - "\n", - "ax2.plot(\n", - " results[\"timestamps\"],\n", - " results[\"distance\"],\n", - " \".-\",\n", - " color=\"red\",\n", - " linewidth=0.5,\n", - " markersize=2,\n", - ")\n", - "ax2.set_xlim(results[\"timestamps\"][0], results[\"timestamps\"][-1])\n", - "ax2.set_ylabel(\"Distance\", fontsize=16)\n", - "ax2.axhline(y=np.mean(results[\"distance\"]), color=\"dimgrey\", linestyle=\"--\")\n", - "ax2.set_xticklabels([])\n", - "ax2.pcolorfast(ax2.get_xlim(), ax2.get_ylim(), sig_drift, cmap=cmap, alpha=0.4)\n", - "\n", - "ax3.plot(\n", - " results[\"timestamps\"],\n", - " results[\"auroc\"],\n", - " \".-\",\n", - " color=\"blue\",\n", - " linewidth=0.5,\n", - " markersize=2,\n", - ")\n", - "ax3.set_xlim(results[\"timestamps\"][0], results[\"timestamps\"][-1])\n", - "ax3.set_ylabel(\"AUROC\", fontsize=16)\n", - "ax3.axhline(y=np.mean(results[\"auroc\"]), color=\"dimgrey\", linestyle=\"--\")\n", - "ax3.set_xticklabels([])\n", - "ax3.pcolorfast(ax3.get_xlim(), ax3.get_ylim(), sig_drift, cmap=cmap, alpha=0.4)\n", - "\n", - "ax4.plot(\n", - " results[\"timestamps\"],\n", - " results[\"auprc\"],\n", - " \".-\",\n", - " color=\"blue\",\n", - " linewidth=0.5,\n", - " markersize=2,\n", - ")\n", - "ax4.set_xlim(results[\"timestamps\"][0], 
results[\"timestamps\"][-1])\n", - "ax4.set_ylabel(\"AUPRC\", fontsize=16)\n", - "ax4.axhline(y=np.mean(results[\"auprc\"]), color=\"dimgrey\", linestyle=\"--\")\n", - "ax4.set_xticklabels([])\n", - "ax4.pcolorfast(ax4.get_xlim(), ax4.get_ylim(), sig_drift, cmap=cmap, alpha=0.4)\n", - "\n", - "ax5.plot(\n", - " results[\"timestamps\"],\n", - " results[\"prec1\"],\n", - " \".-\",\n", - " color=\"blue\",\n", - " linewidth=0.5,\n", - " markersize=2,\n", - ")\n", - "ax5.set_xlim(results[\"timestamps\"][0], results[\"timestamps\"][-1])\n", - "ax5.set_ylabel(\"PPV\", fontsize=16)\n", - "ax5.axhline(y=np.mean(results[\"prec1\"]), color=\"dimgrey\", linestyle=\"--\")\n", - "ax5.set_xticklabels([])\n", - "ax5.pcolorfast(ax5.get_xlim(), ax5.get_ylim(), sig_drift, cmap=cmap, alpha=0.4)\n", - "\n", - "ax6.plot(\n", - " results[\"timestamps\"],\n", - " results[\"rec1\"],\n", - " \".-\",\n", - " color=\"blue\",\n", - " linewidth=0.5,\n", - " markersize=2,\n", - ")\n", - "ax6.set_xlim(results[\"timestamps\"][0], results[\"timestamps\"][-1])\n", - "ax6.set_ylabel(\"Sensitivity\", fontsize=16)\n", - "ax6.set_xlabel(\"time (s)\", fontsize=16)\n", - "ax6.axhline(y=np.mean(results[\"rec1\"]), color=\"dimgrey\", linestyle=\"--\")\n", - "ax6.tick_params(axis=\"x\", labelrotation=45)\n", - "ax6.pcolorfast(ax6.get_xlim(), ax6.get_ylim(), sig_drift, cmap=cmap, alpha=0.4)\n", - "\n", - "for index, label in enumerate(ax6.xaxis.get_ticklabels()):\n", - " if index % 28 != 0:\n", - " label.set_visible(False)\n", - "\n", - "plt.show()" - ] - } - ], - "metadata": { - "kernelspec": { - "display_name": "Python 3.9.7 ('cyclops-4J2PL5I8-py3.9')", - "language": "python", - "name": "python3" - }, - "language_info": { - "codemirror_mode": { - "name": "ipython", - "version": 3 - }, - "file_extension": ".py", - "mimetype": "text/x-python", - "name": "python", - "nbconvert_exporter": "python", - "pygments_lexer": "ipython3", - "version": "3.9.7" - }, - "vscode": { - "interpreter": { - "hash": 
"bd2cd438e1c6ddffa3035fc73b17ac5cc0e0ea8897eb8be17cc645c6abf0c8cc" - } - } - }, - "nbformat": 4, - "nbformat_minor": 5 -} diff --git a/nbs/monitor/rolling_window/static.ipynb b/nbs/monitor/rolling_window/static.ipynb deleted file mode 100644 index 155cc51f8..000000000 --- a/nbs/monitor/rolling_window/static.ipynb +++ /dev/null @@ -1,813 +0,0 @@ -{ - "cells": [ - { - "cell_type": "markdown", - "id": "1937bafb-82ae-4a69-afec-89c3ee787ff5", - "metadata": {}, - "source": [ - "### Evaluating performance and drift of in hospital mortality model (xgboost) using a rolling window" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "id": "9a0204ce-4c3e-44d4-ba4e-87700c720acf", - "metadata": {}, - "outputs": [], - "source": [ - "import datetime\n", - "import os\n", - "import pickle\n", - "import random\n", - "from datetime import date\n", - "\n", - "import matplotlib.pyplot as plt\n", - "import numpy as np\n", - "import pandas as pd\n", - "import scipy.stats as st\n", - "from drift_detection.baseline_models.static.utils import run_model\n", - "from drift_detector.detector import Detector\n", - "from drift_detector.reductor import Reductor\n", - "from drift_detector.rolling_window import RollingWindow, get_label\n", - "from drift_detector.tester import TSTester\n", - "from gemini.query import get_gemini_data\n", - "from gemini.utils import import_dataset_hospital, normalize, process, scale\n", - "from matplotlib.colors import ListedColormap\n", - "from scipy.stats import pearsonr, spearmanr\n", - "\n", - "from cyclops.monitor.utils import get_serving_data" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "id": "9dc229b2-82b7-45f6-808a-f03b438f09ba", - "metadata": {}, - "outputs": [], - "source": [ - "PATH = \"/mnt/nfs/project/delirium/drift_exp/JULY-04-2022/\"\n", - "TIMESTEPS = 6\n", - "AGGREGATION_TYPE = \"time_flatten\"\n", - "HOSPITALS = [\"SMH\", \"MSH\", \"THPC\", \"THPM\", \"UHNTG\", \"UHNTW\", \"PMH\", \"SBK\"]\n", - "OUTCOME = 
\"mortality\"\n", - "THRESHOLD = 0.05\n", - "NUM_TIMESTEPS = 6\n", - "STAT_WINDOW = 30\n", - "LOOKUP_WINDOW = 0\n", - "STRIDE = 1\n", - "\n", - "SHIFT = input(\"Select experiment: \")\n", - "MODEL_PATH = os.path.join(PATH, \"saved_models\", SHIFT + \"_lstm.pt\")\n", - "\n", - "if SHIFT == \"simulated_deployment\":\n", - " exp_params = {\n", - " \"source\": [datetime.date(2015, 1, 1), datetime.date(2019, 1, 1)],\n", - " \"target\": [datetime.date(2019, 1, 1), datetime.date(2020, 8, 1)],\n", - " \"shift_type\": \"source_target\",\n", - " }\n", - "\n", - "if SHIFT == \"covid\":\n", - " exp_params = {\n", - " \"source\": [datetime.date(2019, 1, 1), datetime.date(2020, 2, 1)],\n", - " \"target\": [datetime.date(2020, 3, 1), datetime.date(2020, 8, 1)],\n", - " \"shift_type\": \"time\",\n", - " }\n", - "\n", - "if SHIFT == \"seasonal_summer\":\n", - " exp_params = {\n", - " \"source\": [1, 2, 3, 4, 5, 10, 11, 12],\n", - " \"target\": [6, 7, 8, 9],\n", - " \"shift_type\": \"month\",\n", - " }\n", - "\n", - "if SHIFT == \"seasonal_winter\":\n", - " exp_params = {\n", - " \"source\": [3, 4, 5, 6, 7, 8, 9, 10],\n", - " \"target\": [11, 12, 1, 2],\n", - " \"shift_type\": \"month\",\n", - " }" - ] - }, - { - "cell_type": "markdown", - "id": "1a740e10-af30-42c7-92af-7f470b919798", - "metadata": {}, - "source": [ - "## Get data" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "id": "94627389-7636-4555-a27b-457d3980fad5", - "metadata": {}, - "outputs": [], - "source": [ - "admin_data, x, y = get_gemini_data(PATH)" - ] - }, - { - "cell_type": "markdown", - "id": "3508a128-d17a-4601-b97e-18628c9bdee2", - "metadata": {}, - "source": [ - "## Set constant reference distribution" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "id": "cb25fc74-11f8-4ff7-89dc-93b5e5164d32", - "metadata": {}, - "outputs": [], - "source": [ - "random.seed(1)\n", - "\n", - "(X_tr, y_tr), (X_val, y_val), (X_t, y_t), feats, admin_data = import_dataset_hospital(\n", - " 
admin_data,\n", - " x,\n", - " y,\n", - " SHIFT,\n", - " OUTCOME,\n", - " HOSPITALS,\n", - ")\n", - "\n", - "# Normalize data\n", - "X_tr_normalized = normalize(admin_data, X_tr, AGGREGATION_TYPE, TIMESTEPS)\n", - "X_val_normalized = normalize(admin_data, X_val, AGGREGATION_TYPE, TIMESTEPS)\n", - "X_t_normalized = normalize(admin_data, X_t, AGGREGATION_TYPE, TIMESTEPS)\n", - "\n", - "if AGGREGATION_TYPE != \"time\":\n", - " y_tr = get_label(admin_data, X_tr, OUTCOME)\n", - " y_val = get_label(admin_data, X_val, OUTCOME)\n", - " y_t = get_label(admin_data, X_t, OUTCOME)\n", - "\n", - "# Scale data\n", - "X_tr_scaled = scale(X_tr_normalized)\n", - "X_val_scaled = scale(X_val_normalized)\n", - "X_t_scaled = scale(X_t_normalized)\n", - "\n", - "# Process data\n", - "X_tr_final = process(X_tr_scaled, AGGREGATION_TYPE, TIMESTEPS)\n", - "X_val_final = process(X_val_scaled, AGGREGATION_TYPE, TIMESTEPS)\n", - "X_t_final = process(X_t_scaled, AGGREGATION_TYPE, TIMESTEPS)" - ] - }, - { - "cell_type": "markdown", - "id": "875c1032-4b87-4c8b-bc89-ffcc311259c0", - "metadata": {}, - "source": [ - "## Create data streams" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "id": "4213b586-46d5-4607-ae79-7a262e252ac1", - "metadata": {}, - "outputs": [], - "source": [ - "START_DATE = date(2019, 1, 1)\n", - "END_DATE = date(2020, 8, 1)" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "id": "02b62f72-3384-4686-b661-81ebdb7392fb", - "metadata": {}, - "outputs": [], - "source": [ - "print(\"Get target data streams...\")\n", - "data_streams = get_serving_data(\n", - " x,\n", - " y,\n", - " admin_data,\n", - " START_DATE,\n", - " END_DATE,\n", - " stride=1,\n", - " window=1,\n", - " encounter_id=\"encounter_id\",\n", - " admit_timestamp=\"admit_timestamp\",\n", - ")" - ] - }, - { - "cell_type": "markdown", - "id": "de8fcfb0-9637-4eb9-a1e1-4e4463ba537a", - "metadata": {}, - "source": [ - "## Get prediction model" - ] - }, - { - "cell_type": "code", - 
"execution_count": null, - "id": "7ae5be77-cd44-4e8a-9c3d-402196327952", - "metadata": {}, - "outputs": [], - "source": [ - "MODEL_NAME = input(\"Select Model: \")\n", - "MODEL_PATH = PATH + \"_\".join([SHIFT, OUTCOME, \"_\".join(HOSPITALS), MODEL_NAME]) + \".pkl\"\n", - "if os.path.exists(MODEL_PATH):\n", - " optimised_model = pickle.load(open(MODEL_PATH, \"rb\"))\n", - "else:\n", - " optimised_model = run_model(MODEL_NAME, X_tr_final, y_tr, X_val_final, y_val)\n", - " pickle.dump(optimised_model, open(MODEL_PATH, \"wb\"))" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "id": "e39fdedb-5787-466a-979a-6c3c3f305b73", - "metadata": {}, - "outputs": [], - "source": [ - "DR_TECHNIQUE = \"BBSDs_untrained_FFNN\"\n", - "MD_TEST = \"mmd\"\n", - "SAMPLE = 1000\n", - "CONTEXT_TYPE = \"ffnn\"\n", - "PROJ_TYPE = \"ffnn\"\n", - "\n", - "print(\"Get Shift Reductor...\")\n", - "reductor = Reductor(\n", - " dr_method=DR_TECHNIQUE,\n", - " n_features=len(feats) * TIMESTEPS,\n", - " var_ret=0.8,\n", - ")\n", - "\n", - "print(\"Get Shift Tester...\")\n", - "tester = TSTester(tester_method=MD_TEST)\n", - "\n", - "print(\"Get Shift Detector...\")\n", - "detector = Detector(\n", - " reductor=reductor,\n", - " tester=tester,\n", - " p_val_threshold=0.05,\n", - ")\n", - "detector.fit(X_val_final)" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "id": "7ec76cae-eb78-4566-9eee-b335ccee81c9", - "metadata": {}, - "outputs": [], - "source": [ - "print(\"Get Rolling Window...\")\n", - "\n", - "rolling_window = RollingWindow(\n", - " admin_data=admin_data,\n", - " shift_detector=detector,\n", - " model=optimised_model,\n", - ")\n", - "\n", - "all_runs = []\n", - "for i in range(0, 1):\n", - " random.seed(1)\n", - " np.random.seed(1)\n", - "\n", - " drift_metrics = rolling_window.drift(\n", - " data_streams=data_streams,\n", - " sample=SAMPLE,\n", - " stat_window=STAT_WINDOW,\n", - " lookup_window=LOOKUP_WINDOW,\n", - " stride=STRIDE,\n", - " 
model_path=MODEL_PATH,\n", - " context_type=CONTEXT_TYPE,\n", - " proj_type=PROJ_TYPE,\n", - " aggregation_type=AGGREGATION_TYPE,\n", - " )\n", - "\n", - " performance_metrics = rolling_window.performance(\n", - " data_streams=data_streams,\n", - " stat_window=STAT_WINDOW,\n", - " lookup_window=LOOKUP_WINDOW,\n", - " stride=STRIDE,\n", - " aggregation_type=AGGREGATION_TYPE,\n", - " )\n", - "\n", - " results = {\n", - " \"timestamps\": [\n", - " (\n", - " datetime.datetime.strptime(date, \"%Y-%m-%d\")\n", - " + datetime.timedelta(days=LOOKUP_WINDOW + STAT_WINDOW)\n", - " ).strftime(\"%Y-%m-%d\")\n", - " for date in data_streams[\"timestamps\"]\n", - " ][:-STAT_WINDOW],\n", - " }\n", - " results.update(drift_metrics)\n", - " results.update(performance_metrics)\n", - "\n", - " all_runs.append(results)" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "id": "65591c18-6d08-4f1e-86e6-2bf6e3fc7b27", - "metadata": {}, - "outputs": [], - "source": [ - "avgDict = {}\n", - "for k, v in results.items():\n", - " if not all(isinstance(s, str) for s in v):\n", - " mean = sum(v) / float(len(v))\n", - " ci = st.t.interval(0.95, len(v), loc=np.mean(v), scale=st.sem(v))\n", - " avgDict[k] = [mean, ci]\n", - "avgDict" - ] - }, - { - "cell_type": "markdown", - "id": "3bfa1dee-dce1-461e-958c-7870e4050984", - "metadata": {}, - "source": [ - "## Plot Drift and Prediction Performance" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "id": "a9c4178f-53af-4787-957e-95dc1b37912d", - "metadata": {}, - "outputs": [], - "source": [ - "threshold = 0.05\n", - "sig_drift = np.array(results[\"shift_detected\"])[np.newaxis]\n", - "\n", - "fig, (ax1, ax2, ax3, ax4, ax5, ax6) = plt.subplots(6, 1, figsize=(18, 12))\n", - "cmap = ListedColormap([\"lightgrey\", \"red\"])\n", - "ax1.plot(\n", - " results[\"timestamps\"],\n", - " results[\"p_val\"],\n", - " \".-\",\n", - " color=\"red\",\n", - " linewidth=0.5,\n", - " markersize=2,\n", - ")\n", - 
"ax1.set_xlim(results[\"timestamps\"][0], results[\"timestamps\"][-1])\n", - "ax1.axhline(y=threshold, color=\"dimgrey\", linestyle=\"--\")\n", - "ax1.set_ylabel(\"P-Values\", fontsize=16)\n", - "ax1.set_xticklabels([])\n", - "ax1.pcolorfast(ax1.get_xlim(), ax1.get_ylim(), sig_drift, cmap=cmap, alpha=0.4)\n", - "\n", - "ax2.plot(\n", - " results[\"timestamps\"],\n", - " results[\"distance\"],\n", - " \".-\",\n", - " color=\"red\",\n", - " linewidth=0.5,\n", - " markersize=2,\n", - ")\n", - "ax2.set_xlim(results[\"timestamps\"][0], results[\"timestamps\"][-1])\n", - "ax2.set_ylabel(\"Distance\", fontsize=16)\n", - "ax2.axhline(y=np.mean(results[\"distance\"]), color=\"dimgrey\", linestyle=\"--\")\n", - "ax2.set_xticklabels([])\n", - "ax2.pcolorfast(ax2.get_xlim(), ax2.get_ylim(), sig_drift, cmap=cmap, alpha=0.4)\n", - "\n", - "ax3.plot(\n", - " results[\"timestamps\"],\n", - " results[\"auroc\"],\n", - " \".-\",\n", - " color=\"blue\",\n", - " linewidth=0.5,\n", - " markersize=2,\n", - ")\n", - "ax3.set_xlim(results[\"timestamps\"][0], results[\"timestamps\"][-1])\n", - "ax3.set_ylabel(\"AUROC\", fontsize=16)\n", - "ax3.axhline(y=np.mean(results[\"auroc\"]), color=\"dimgrey\", linestyle=\"--\")\n", - "ax3.set_xticklabels([])\n", - "ax3.pcolorfast(ax3.get_xlim(), ax3.get_ylim(), sig_drift, cmap=cmap, alpha=0.4)\n", - "\n", - "ax4.plot(\n", - " results[\"timestamps\"],\n", - " results[\"auprc\"],\n", - " \".-\",\n", - " color=\"blue\",\n", - " linewidth=0.5,\n", - " markersize=2,\n", - ")\n", - "ax4.set_xlim(results[\"timestamps\"][0], results[\"timestamps\"][-1])\n", - "ax4.set_ylabel(\"AUPRC\", fontsize=16)\n", - "ax4.axhline(y=np.mean(results[\"auprc\"]), color=\"dimgrey\", linestyle=\"--\")\n", - "ax4.set_xticklabels([])\n", - "ax4.pcolorfast(ax4.get_xlim(), ax4.get_ylim(), sig_drift, cmap=cmap, alpha=0.4)\n", - "\n", - "ax5.plot(\n", - " results[\"timestamps\"],\n", - " results[\"prec1\"],\n", - " \".-\",\n", - " color=\"blue\",\n", - " linewidth=0.5,\n", - " 
markersize=2,\n", - ")\n", - "ax5.set_xlim(results[\"timestamps\"][0], results[\"timestamps\"][-1])\n", - "ax5.set_ylabel(\"PPV\", fontsize=16)\n", - "ax5.axhline(y=np.mean(results[\"prec1\"]), color=\"dimgrey\", linestyle=\"--\")\n", - "ax5.set_xticklabels([])\n", - "ax5.pcolorfast(ax5.get_xlim(), ax5.get_ylim(), sig_drift, cmap=cmap, alpha=0.4)\n", - "\n", - "ax6.plot(\n", - " results[\"timestamps\"],\n", - " results[\"rec1\"],\n", - " \".-\",\n", - " color=\"blue\",\n", - " linewidth=0.5,\n", - " markersize=2,\n", - ")\n", - "ax6.set_xlim(results[\"timestamps\"][0], results[\"timestamps\"][-1])\n", - "ax6.set_ylabel(\"Sensitivity\", fontsize=16)\n", - "ax6.set_xlabel(\"time (s)\", fontsize=16)\n", - "ax6.axhline(y=np.mean(results[\"rec1\"]), color=\"dimgrey\", linestyle=\"--\")\n", - "ax6.tick_params(axis=\"x\", labelrotation=45)\n", - "ax6.pcolorfast(ax6.get_xlim(), ax6.get_ylim(), sig_drift, cmap=cmap, alpha=0.4)\n", - "\n", - "for index, label in enumerate(ax6.xaxis.get_ticklabels()):\n", - " if index % 28 != 0:\n", - " label.set_visible(False)\n", - "\n", - "plt.show()" - ] - }, - { - "cell_type": "markdown", - "id": "9d883afb-2a5a-4c09-8fb0-f80c264a7f49", - "metadata": {}, - "source": [ - "## Retraining: Drift Alarms " - ] - }, - { - "cell_type": "markdown", - "id": "c39a259d-0394-48f8-b52a-e58c25f05368", - "metadata": {}, - "source": [ - "### Drift Alarms" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "id": "2af5b979-7243-4675-97a7-dcc4b47ae7e8", - "metadata": {}, - "outputs": [], - "source": [ - "baseline = [127, 118, 119, 123, 127]\n", - "mostrecent30 = [132, 116, 97, 98, 128]\n", - "mostrecent60 = [100, 96, 108, 97, 97]\n", - "mostrecent120 = [96, 76, 101, 67, 89]\n", - "cumulative = [72, 112, 64, 85, 107]\n", - "\n", - "retraining_drift = pd.DataFrame(\n", - " {\n", - " \"Baseline\": baseline,\n", - " \"Most Recent \\n(30 days)\": mostrecent30,\n", - " \"Most Recent \\n(60 days)\": mostrecent60,\n", - " \"Most Recent \\n(120 days)\": 
mostrecent120,\n", - " \"Cumulative\": cumulative,\n", - " },\n", - ")\n", - "fig, ax = plt.subplots(figsize=(7, 4))\n", - "ax.boxplot(retraining_drift, patch_artist=True)\n", - "ax.set_xticks([1, 2, 3, 4, 5], retraining_drift.columns, rotation=45, fontsize=12)\n", - "ax.set_xlabel(\"Retraining Strategies\", fontsize=12)\n", - "ax.set_ylabel(\"Number of Drift Alarms\", fontsize=12)\n", - "plt.show()" - ] - }, - { - "cell_type": "markdown", - "id": "52d780a7-1c4d-4466-93b8-a72af9c3eb57", - "metadata": {}, - "source": [ - "### Number of Epochs" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "id": "414f713f-ed48-49e6-8e01-51fbb62888cd", - "metadata": {}, - "outputs": [], - "source": [ - "baseline = [127, 118, 119, 123, 127]\n", - "mostrecent120 = [96, 76, 101, 67, 89]\n", - "mostrecent120_10 = [97, 103, 98, 64, 94]\n", - "\n", - "retraining_drift = pd.DataFrame(\n", - " {\n", - " \"Baseline\": baseline,\n", - " \"Most Recent \\n(120 days, 1 epoch)\": mostrecent120,\n", - " \"Most Recent\\n (120 days, 10 epochs)\": mostrecent120_10,\n", - " },\n", - ")\n", - "fig, ax = plt.subplots(figsize=(7, 4))\n", - "ax.boxplot(retraining_drift, patch_artist=True)\n", - "ax.set_xticks([1, 2, 3], retraining_drift.columns, rotation=45, fontsize=12)\n", - "ax.set_xlabel(\"Retraining Strategies\", fontsize=12)\n", - "ax.set_ylabel(\"Number of Drift Alarms\", fontsize=12)\n", - "plt.show()" - ] - }, - { - "cell_type": "markdown", - "id": "0aceca93-ac9a-41a9-8c1a-a2702acbe587", - "metadata": {}, - "source": [ - "### Drift Threshold" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "id": "52a00d0e-3c6b-4527-a0cb-fd2cc31a02ab", - "metadata": {}, - "outputs": [], - "source": [ - "mostrecent120_10_2 = [50, 44, 40, 51, 61]\n", - "mostrecent120 = [96, 76, 101, 67, 89]\n", - "mostrecent120_10_1 = [121, 150, 123, 139, 131]\n", - "\n", - "retraining_drift = pd.DataFrame(\n", - " {\n", - " \"P-Val=0.01\": mostrecent120_10_2,\n", - " \"P-Val=0.05\": 
mostrecent120,\n", - " \"P-Val=0.1\": mostrecent120_10_1,\n", - " },\n", - ")\n", - "fig, ax = plt.subplots(figsize=(7, 4))\n", - "ax.boxplot(retraining_drift, patch_artist=True)\n", - "ax.set_xticks([1, 2, 3], retraining_drift.columns, rotation=45, fontsize=12)\n", - "ax.set_xlabel(\"Retraining Strategies\", fontsize=12)\n", - "ax.set_ylabel(\"Number of Drift Alarms\", fontsize=12)\n", - "plt.show()" - ] - }, - { - "cell_type": "markdown", - "id": "f972b02a-f2f4-4e84-8891-a1bc650d34b2", - "metadata": {}, - "source": [ - "## Retraining: PPV & Sensitivity" - ] - }, - { - "cell_type": "markdown", - "id": "44d7cbaf-d645-40b3-9e0f-a8a324298c23", - "metadata": {}, - "source": [ - "### Window Size" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "id": "d22682c3-5501-44cc-96bf-4099eb9622f1", - "metadata": {}, - "outputs": [], - "source": [ - "types = [\n", - " \"baseline\",\n", - " \"mostrecent30\",\n", - " \"mostrecent60\",\n", - " \"mostrecent120\",\n", - " \"cumulative_1epoch\",\n", - "]\n", - "labels = [\n", - " \"Baseline\",\n", - " \"Most Recent\\n(30days)\",\n", - " \"Most Recent\\n(60 days)\",\n", - " \"Most Recent\\n(120 days)\",\n", - " \"Cumulative\",\n", - "]\n", - "\n", - "drift_sensitivity = []\n", - "drift_ppv = []\n", - "for retraining_type in types:\n", - " for i in range(0, 5):\n", - " res_path = os.path.join(\n", - " PATH,\n", - " SHIFT,\n", - " SHIFT + \"_\" + retraining_type + \"_retraining_update.npy\",\n", - " )\n", - " cum = np.load(res_path, allow_pickle=True)[i]\n", - " drift_sensitivity.append(np.mean(cum[\"performance\"][\"rec1\"]))\n", - " drift_ppv.append(np.mean(cum[\"performance\"][\"prec1\"]))\n", - " # drift_sensitivity.append(np.mean(cum['performance']\n", - " # ['rec1'][[i for i,v in enumerate(cum['pval']) if v < 0.05]]))\n", - " # drift_ppv.append(np.mean(cum['performance']['prec1'][\n", - " # [i for i,v in enumerate(cum['pval']) if v < 0.05]]))\n", - "\n", - "retraining_drift = pd.DataFrame(\n", - " {\n", - " 
\"Retraining Strategy\": np.repeat(types, 5),\n", - " \"PPV\": drift_ppv,\n", - " \"Sensitivity\": drift_sensitivity,\n", - " },\n", - ")\n", - "\n", - "fig, axs = plt.subplots(1, 2, figsize=(14, 4))\n", - "for j, variable in enumerate([\"PPV\", \"Sensitivity\"]):\n", - " for i, grp in enumerate(retraining_drift.groupby(\"Retraining Strategy\")):\n", - " axs[j].boxplot(\n", - " x=variable,\n", - " data=grp[1],\n", - " positions=[i],\n", - " widths=0.4,\n", - " patch_artist=True,\n", - " )\n", - " axs[j].set_xticks(range(0, len(types)), labels, rotation=45, fontsize=12)\n", - " axs[j].set_ylabel(variable, fontsize=12)\n", - "plt.show()" - ] - }, - { - "cell_type": "markdown", - "id": "68e011a5-92d0-4ebf-adca-83ae40c32f07", - "metadata": {}, - "source": [ - "### Number of Epochs" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "id": "395c27bf-4553-473c-8fe4-f95e99a03ad4", - "metadata": {}, - "outputs": [], - "source": [ - "types = [\"baseline\", \"mostrecent120\", \"mostrecent120_10epochs\"]\n", - "labels = [\"Baseline\", \"Most Recent\\n(120 days)\", \"Most Recent\\n(120 days, 10 epochs)\"]\n", - "\n", - "drift_sensitivity = []\n", - "drift_ppv = []\n", - "for retraining_type in types:\n", - " for i in range(0, 5):\n", - " res_path = os.path.join(\n", - " PATH,\n", - " SHIFT,\n", - " SHIFT + \"_\" + retraining_type + \"_retraining_update.npy\",\n", - " )\n", - " cum = np.load(res_path, allow_pickle=True)[i]\n", - " drift_sensitivity.append(np.mean(cum[\"performance\"][\"rec1\"]))\n", - " drift_ppv.append(np.mean(cum[\"performance\"][\"prec1\"]))\n", - " # drift_sensitivity.append(np.mean(cum['performance']\n", - " # ['rec1'][[i for i,v in enumerate(cum['pval']) if v < 0.05]]))\n", - " # drift_ppv.append(np.mean(cum['performance']['prec1']\n", - " # [[i for i,v in enumerate(cum['pval']) if v < 0.05]]))\n", - "\n", - "retraining_drift = pd.DataFrame(\n", - " {\n", - " \"Retraining Strategy\": np.repeat(types, 5),\n", - " \"PPV\": drift_ppv,\n", - " 
\"Sensitivity\": drift_sensitivity,\n", - " },\n", - ")\n", - "\n", - "fig, axs = plt.subplots(1, 2, figsize=(14, 4))\n", - "for j, variable in enumerate([\"PPV\", \"Sensitivity\"]):\n", - " for i, grp in enumerate(retraining_drift.groupby(\"Retraining Strategy\")):\n", - " axs[j].boxplot(\n", - " x=variable,\n", - " data=grp[1],\n", - " positions=[i],\n", - " widths=0.4,\n", - " patch_artist=True,\n", - " )\n", - " axs[j].set_xticks(range(0, len(types)), labels, rotation=45, fontsize=12)\n", - " axs[j].set_ylabel(variable, fontsize=12)\n", - "plt.show()" - ] - }, - { - "cell_type": "markdown", - "id": "97902e7f-d9c4-4c05-81e2-57ba1dde34e0", - "metadata": {}, - "source": [ - "### Drift Threshold" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "id": "9c58a8e0-e101-4ecc-a1e8-e04fd366cd60", - "metadata": {}, - "outputs": [], - "source": [ - "types = [\n", - " \"mostrecent120_1epoch_pval0.01\",\n", - " \"mostrecent120\",\n", - " \"mostrecent120_1epoch_pval0.1\",\n", - "]\n", - "labels = [\"P-Val=0.01\", \"P-Val=0.05\", \"P-Val=0.1\"]\n", - "\n", - "drift_sensitivity = []\n", - "drift_ppv = []\n", - "for retraining_type in types:\n", - " for i in range(0, 5):\n", - " res_path = os.path.join(\n", - " PATH,\n", - " SHIFT,\n", - " SHIFT + \"_\" + retraining_type + \"_retraining_update.npy\",\n", - " )\n", - " cum = np.load(res_path, allow_pickle=True)[i]\n", - " drift_sensitivity.append(np.mean(cum[\"performance\"][\"rec1\"]))\n", - " drift_ppv.append(np.mean(cum[\"performance\"][\"prec1\"]))\n", - " # drift_sensitivity.append(np.mean(cum['performance']\n", - " # ['rec1'][[i for i,v in enumerate(cum['pval']) if v < 0.05]]))\n", - " # drift_ppv.append(np.mean(cum['performance']['prec1']\n", - " # [[i for i,v in enumerate(cum['pval']) if v < 0.05]]))\n", - "\n", - "retraining_drift = pd.DataFrame(\n", - " {\n", - " \"Retraining Strategy\": np.repeat(types, 5),\n", - " \"PPV\": drift_ppv,\n", - " \"Sensitivity\": drift_sensitivity,\n", - " },\n", - ")\n", - 
"\n", - "fig, axs = plt.subplots(1, 2, figsize=(14, 4))\n", - "for j, variable in enumerate([\"PPV\", \"Sensitivity\"]):\n", - " for i, grp in enumerate(retraining_drift.groupby(\"Retraining Strategy\")):\n", - " axs[j].boxplot(\n", - " x=variable,\n", - " data=grp[1],\n", - " positions=[i],\n", - " widths=0.4,\n", - " patch_artist=True,\n", - " )\n", - " axs[j].set_xticks(range(0, len(types)), labels, rotation=45, fontsize=12)\n", - " axs[j].set_ylabel(variable, fontsize=12)\n", - "plt.show()" - ] - }, - { - "cell_type": "markdown", - "id": "d92f0a10-92fb-479d-a626-c5a28dbf88b8", - "metadata": {}, - "source": [ - "## Relationship between Performance and Drift P-Value" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "id": "47ea7822-b83a-41ee-9bc4-394eb13b296b", - "metadata": {}, - "outputs": [], - "source": [ - "# calculate Pearson's correlation\n", - "pcorr, pcorr_pval = pearsonr(results[\"prec1\"], results[\"pval\"])\n", - "print(\"Pearsons correlation: %.3f P-Value: %.3f\" % (pcorr, pcorr_pval))\n", - "# calculate spearman's correlation\n", - "scorr, scorr_pval = spearmanr(results[\"prec1\"], results[\"pval\"])\n", - "print(\"Spearmans correlation: %.3f P-Value: %.3f\" % (scorr, scorr_pval))\n", - "# plot\n", - "plt.scatter(results[\"prec1\"], results[\"pval\"])\n", - "plt.show()" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "id": "115e6398-411d-4c69-8770-9c76f30f3da5", - "metadata": {}, - "outputs": [], - "source": [ - "# calculate Pearson's correlation\n", - "pcorr, pcorr_pval = pearsonr(results[\"rec1\"], results[\"pval\"])\n", - "print(\"Pearsons correlation: %.3f P-Value: %.3f\" % (pcorr, pcorr_pval))\n", - "# calculate spearman's correlation\n", - "scorr, scorr_pval = spearmanr(results[\"rec1\"], results[\"pval\"])\n", - "print(\"Spearmans correlation: %.3f P-Value: %.3f\" % (scorr, scorr_pval))\n", - "# plot\n", - "plt.scatter(results[\"rec1\"], results[\"pval\"])\n", - "plt.show()" - ] - } - ], - "metadata": { - 
"kernelspec": { - "display_name": "cyclops-KKtuQLwg-py3.9", - "language": "python", - "name": "cyclops-kktuqlwg-py3.9" - }, - "language_info": { - "codemirror_mode": { - "name": "ipython", - "version": 3 - }, - "file_extension": ".py", - "mimetype": "text/x-python", - "name": "python", - "nbconvert_exporter": "python", - "pygments_lexer": "ipython3", - "version": "3.9.7" - } - }, - "nbformat": 4, - "nbformat_minor": 5 -} diff --git a/nbs/monitor/rolling_window/temporal.ipynb b/nbs/monitor/rolling_window/temporal.ipynb deleted file mode 100644 index 2e363468c..000000000 --- a/nbs/monitor/rolling_window/temporal.ipynb +++ /dev/null @@ -1,829 +0,0 @@ -{ - "cells": [ - { - "cell_type": "markdown", - "id": "4d73413b-4e2b-44cc-bba6-385bef6f7894", - "metadata": {}, - "source": [ - "### Evaluating performance and drift of mortality decompensation model (LSTM) using a rolling window" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "id": "9a0204ce-4c3e-44d4-ba4e-87700c720acf", - "metadata": {}, - "outputs": [], - "source": [ - "import datetime\n", - "import os\n", - "import random\n", - "from datetime import date\n", - "\n", - "import matplotlib.pyplot as plt\n", - "import numpy as np\n", - "import pandas as pd\n", - "import scipy.stats as st\n", - "from baseline_models.temporal.pytorch.utils import get_device, load_ckp\n", - "from drift_detector.detector import Detector\n", - "from drift_detector.reductor import Reductor\n", - "from drift_detector.rolling_window import RollingWindow\n", - "from drift_detector.tester import TSTester\n", - "from drift_detector.utils import get_serving_data, get_temporal_model\n", - "from gemini.query import get_gemini_data\n", - "from gemini.utils import get_label, import_dataset_hospital, normalize, process, scale\n", - "from matplotlib.colors import ListedColormap\n", - "from scipy.stats import pearsonr, spearmanr" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "id": "9dc229b2-82b7-45f6-808a-f03b438f09ba", 
- "metadata": {}, - "outputs": [], - "source": [ - "PATH = \"/mnt/nfs/project/delirium/drift_exp/JULY-04-2022/\"\n", - "TIMESTEPS = 6\n", - "AGGREGATION_TYPE = \"time\"\n", - "HOSPITALS = [\"SMH\", \"MSH\", \"THPC\", \"THPM\", \"UHNTG\", \"UHNTW\", \"PMH\", \"SBK\"]\n", - "OUTCOME = \"mortality\"\n", - "THRESHOLD = 0.05\n", - "NUM_TIMESTEPS = 6\n", - "STAT_WINDOW = 30\n", - "LOOKUP_WINDOW = 0\n", - "STRIDE = 1\n", - "\n", - "SHIFT = input(\"Select experiment: \") # hospital_type\n", - "MODEL_PATH = os.path.join(PATH, \"saved_models\", SHIFT + \"_lstm.pt\")\n", - "\n", - "if SHIFT == \"simulated_deployment\":\n", - " exp_params = {\n", - " \"source\": [datetime.date(2015, 1, 1), datetime.date(2019, 1, 1)],\n", - " \"target\": [datetime.date(2019, 1, 1), datetime.date(2020, 8, 1)],\n", - " \"shift_type\": \"source_target\",\n", - " }\n", - "\n", - "if SHIFT == \"covid\":\n", - " exp_params = {\n", - " \"source\": [datetime.date(2019, 1, 1), datetime.date(2020, 2, 1)],\n", - " \"target\": [datetime.date(2020, 3, 1), datetime.date(2020, 8, 1)],\n", - " \"shift_type\": \"time\",\n", - " }\n", - "\n", - "if SHIFT == \"seasonal_summer\":\n", - " exp_params = {\n", - " \"source\": [1, 2, 3, 4, 5, 10, 11, 12],\n", - " \"target\": [6, 7, 8, 9],\n", - " \"shift_type\": \"month\",\n", - " }\n", - "\n", - "if SHIFT == \"seasonal_winter\":\n", - " exp_params = {\n", - " \"source\": [3, 4, 5, 6, 7, 8, 9, 10],\n", - " \"target\": [11, 12, 1, 2],\n", - " \"shift_type\": \"month\",\n", - " }" - ] - }, - { - "cell_type": "markdown", - "id": "1a740e10-af30-42c7-92af-7f470b919798", - "metadata": {}, - "source": [ - "## Get data" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "id": "94627389-7636-4555-a27b-457d3980fad5", - "metadata": {}, - "outputs": [], - "source": [ - "admin_data, x, y = get_gemini_data(PATH)" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "id": "cb25fc74-11f8-4ff7-89dc-93b5e5164d32", - "metadata": {}, - "outputs": [], - "source": 
[ - "random.seed(1)\n", - "\n", - "(X_tr, y_tr), (X_val, y_val), (X_t, y_t), feats, admin_data = import_dataset_hospital(\n", - " admin_data,\n", - " x,\n", - " y,\n", - " SHIFT,\n", - " OUTCOME,\n", - " HOSPITALS,\n", - ")\n", - "\n", - "# Normalize data\n", - "X_tr_normalized = normalize(admin_data, X_tr, AGGREGATION_TYPE, TIMESTEPS)\n", - "X_val_normalized = normalize(admin_data, X_val, AGGREGATION_TYPE, TIMESTEPS)\n", - "X_t_normalized = normalize(admin_data, X_t, AGGREGATION_TYPE, TIMESTEPS)\n", - "\n", - "if AGGREGATION_TYPE != \"time\":\n", - " y_tr = get_label(admin_data, X_tr, OUTCOME)\n", - " y_val = get_label(admin_data, X_val, OUTCOME)\n", - " y_t = get_label(admin_data, X_t, OUTCOME)\n", - "\n", - "# Scale data\n", - "X_tr_scaled = scale(X_tr_normalized)\n", - "X_val_scaled = scale(X_val_normalized)\n", - "X_t_scaled = scale(X_t_normalized)\n", - "\n", - "# Process data\n", - "X_tr_final = process(X_tr_scaled, AGGREGATION_TYPE, TIMESTEPS)\n", - "X_val_final = process(X_val_scaled, AGGREGATION_TYPE, TIMESTEPS)\n", - "X_t_final = process(X_t_scaled, AGGREGATION_TYPE, TIMESTEPS)" - ] - }, - { - "cell_type": "markdown", - "id": "875c1032-4b87-4c8b-bc89-ffcc311259c0", - "metadata": {}, - "source": [ - "## Create data streams" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "id": "4213b586-46d5-4607-ae79-7a262e252ac1", - "metadata": {}, - "outputs": [], - "source": [ - "START_DATE = date(2019, 1, 1)\n", - "END_DATE = date(2020, 8, 1)" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "id": "02b62f72-3384-4686-b661-81ebdb7392fb", - "metadata": {}, - "outputs": [], - "source": [ - "print(\"Get target data streams...\")\n", - "data_streams = get_serving_data(\n", - " x,\n", - " y,\n", - " admin_data,\n", - " START_DATE,\n", - " END_DATE,\n", - " stride=1,\n", - " window=1,\n", - " encounter_id=\"encounter_id\",\n", - " admit_timestamp=\"admit_timestamp\",\n", - ")" - ] - }, - { - "cell_type": "markdown", - "id": 
"de8fcfb0-9637-4eb9-a1e1-4e4463ba537a", - "metadata": {}, - "source": [ - "## Get prediction model" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "id": "7ae5be77-cd44-4e8a-9c3d-402196327952", - "metadata": {}, - "outputs": [], - "source": [ - "output_dim = 1\n", - "input_dim = 108\n", - "hidden_dim = 64\n", - "layer_dim = 2\n", - "dropout = 0.2\n", - "last_timestep_only = False\n", - "device = get_device()\n", - "\n", - "model_params = {\n", - " \"device\": device,\n", - " \"input_dim\": input_dim,\n", - " \"hidden_dim\": hidden_dim,\n", - " \"layer_dim\": layer_dim,\n", - " \"output_dim\": output_dim,\n", - " \"dropout_prob\": dropout,\n", - " \"last_timestep_only\": last_timestep_only,\n", - "}\n", - "\n", - "model = get_temporal_model(\"lstm\", model_params).to(device)\n", - "model, optimizer, n_epochs = load_ckp(MODEL_PATH, model)" - ] - }, - { - "cell_type": "markdown", - "id": "fc9edf9d-f6c1-4729-aea3-530e223c449c", - "metadata": {}, - "source": [ - "## Get shift detector" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "id": "e39fdedb-5787-466a-979a-6c3c3f305b73", - "metadata": {}, - "outputs": [], - "source": [ - "DR_TECHNIQUE = \"BBSDs_trained_LSTM\"\n", - "MD_TEST = \"mmd\"\n", - "SAMPLE = 1000\n", - "CONTEXT_TYPE = \"lstm\"\n", - "PROJ_TYPE = \"lstm\"\n", - "\n", - "print(\"Get Shift Reductor...\")\n", - "reductor = Reductor(\n", - " dr_method=DR_TECHNIQUE,\n", - " model_path=MODEL_PATH,\n", - " n_features=len(feats),\n", - " var_ret=0.8,\n", - ")\n", - "\n", - "print(\"Get Shift Tester...\")\n", - "tester = TSTester(tester_method=MD_TEST)\n", - "\n", - "print(\"Get Shift Detector...\")\n", - "detector = Detector(\n", - " reductor=reductor,\n", - " tester=tester,\n", - " p_val_threshold=0.05,\n", - ")\n", - "detector.fit(X_tr_final)" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "id": "a37d3489-4f90-4360-8166-684281ef3c88", - "metadata": {}, - "outputs": [], - "source": [ - "print(\"Get Rolling 
Window...\")\n", - "\n", - "rolling_window = RollingWindow(\n", - " admin_data=admin_data,\n", - " shift_detector=detector,\n", - " optimizer=optimizer,\n", - ")\n", - "\n", - "all_runs = []\n", - "for i in range(0, 1):\n", - " random.seed(1)\n", - " np.random.seed(1)\n", - "\n", - " drift_metrics = rolling_window.drift(\n", - " data_streams=data_streams,\n", - " sample=SAMPLE,\n", - " stat_window=STAT_WINDOW,\n", - " lookup_window=LOOKUP_WINDOW,\n", - " stride=STRIDE,\n", - " model_path=MODEL_PATH,\n", - " context_type=CONTEXT_TYPE,\n", - " proj_type=PROJ_TYPE,\n", - " )\n", - "\n", - " performance_metrics = rolling_window.performance(\n", - " data_streams=data_streams,\n", - " stat_window=STAT_WINDOW,\n", - " lookup_window=LOOKUP_WINDOW,\n", - " stride=STRIDE,\n", - " )\n", - "\n", - " results = {\n", - " \"timestamps\": [\n", - " (\n", - " datetime.datetime.strptime(date, \"%Y-%m-%d\")\n", - " + datetime.timedelta(days=LOOKUP_WINDOW + STAT_WINDOW)\n", - " ).strftime(\"%Y-%m-%d\")\n", - " for date in data_streams[\"timestamps\"]\n", - " ][:-STAT_WINDOW],\n", - " }\n", - " results.update(drift_metrics)\n", - " results.update(performance_metrics)\n", - "\n", - " all_runs.append(results)\n", - "np.save(os.path.join(PATH, SHIFT, SHIFT + \"_rolling_window.npy\"), all_runs)" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "id": "65591c18-6d08-4f1e-86e6-2bf6e3fc7b27", - "metadata": {}, - "outputs": [], - "source": [ - "avgDict = {}\n", - "for k, v in results.items():\n", - " if not all(isinstance(s, str) for s in v):\n", - " mean = sum(v) / float(len(v))\n", - " ci = st.t.interval(0.95, len(v), loc=np.mean(v), scale=st.sem(v))\n", - " avgDict[k] = [mean, ci]\n", - "avgDict" - ] - }, - { - "cell_type": "markdown", - "id": "3bfa1dee-dce1-461e-958c-7870e4050984", - "metadata": {}, - "source": [ - "## Plot Drift and Prediction Performance" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "id": "a9c4178f-53af-4787-957e-95dc1b37912d", - 
"metadata": {}, - "outputs": [], - "source": [ - "p_val_threshold = 0.05\n", - "sig_drift = np.array(results[\"shift_detected\"])[np.newaxis]\n", - "\n", - "fig, (ax1, ax2, ax3, ax4, ax5, ax6) = plt.subplots(6, 1, figsize=(18, 12))\n", - "cmap = ListedColormap([\"lightgrey\", \"red\"])\n", - "ax1.plot(\n", - " results[\"timestamps\"],\n", - " results[\"p_val\"],\n", - " \".-\",\n", - " color=\"red\",\n", - " linewidth=0.5,\n", - " markersize=2,\n", - ")\n", - "ax1.set_xlim(results[\"timestamps\"][0], results[\"timestamps\"][-1])\n", - "ax1.axhline(y=p_val_threshold, color=\"dimgrey\", linestyle=\"--\")\n", - "ax1.set_ylabel(\"P-Values\", fontsize=16)\n", - "ax1.set_xticklabels([])\n", - "ax1.pcolorfast(ax1.get_xlim(), ax1.get_ylim(), sig_drift, cmap=cmap, alpha=0.4)\n", - "\n", - "ax2.plot(\n", - " results[\"timestamps\"],\n", - " results[\"distance\"],\n", - " \".-\",\n", - " color=\"red\",\n", - " linewidth=0.5,\n", - " markersize=2,\n", - ")\n", - "ax2.set_xlim(results[\"timestamps\"][0], results[\"timestamps\"][-1])\n", - "ax2.set_ylabel(\"Distance\", fontsize=16)\n", - "ax2.axhline(y=np.mean(results[\"distance\"]), color=\"dimgrey\", linestyle=\"--\")\n", - "ax2.set_xticklabels([])\n", - "ax2.pcolorfast(ax2.get_xlim(), ax2.get_ylim(), sig_drift, cmap=cmap, alpha=0.4)\n", - "\n", - "ax3.plot(\n", - " results[\"timestamps\"],\n", - " results[\"auroc\"],\n", - " \".-\",\n", - " color=\"blue\",\n", - " linewidth=0.5,\n", - " markersize=2,\n", - ")\n", - "ax3.set_xlim(results[\"timestamps\"][0], results[\"timestamps\"][-1])\n", - "ax3.set_ylabel(\"AUROC\", fontsize=16)\n", - "ax3.axhline(y=np.mean(results[\"auroc\"]), color=\"dimgrey\", linestyle=\"--\")\n", - "ax3.set_xticklabels([])\n", - "ax3.pcolorfast(ax3.get_xlim(), ax3.get_ylim(), sig_drift, cmap=cmap, alpha=0.4)\n", - "\n", - "ax4.plot(\n", - " results[\"timestamps\"],\n", - " results[\"auprc\"],\n", - " \".-\",\n", - " color=\"blue\",\n", - " linewidth=0.5,\n", - " markersize=2,\n", - ")\n", - 
"ax4.set_xlim(results[\"timestamps\"][0], results[\"timestamps\"][-1])\n", - "ax4.set_ylabel(\"AUPRC\", fontsize=16)\n", - "ax4.axhline(y=np.mean(results[\"auprc\"]), color=\"dimgrey\", linestyle=\"--\")\n", - "ax4.set_xticklabels([])\n", - "ax4.pcolorfast(ax4.get_xlim(), ax4.get_ylim(), sig_drift, cmap=cmap, alpha=0.4)\n", - "\n", - "ax5.plot(\n", - " results[\"timestamps\"],\n", - " results[\"prec1\"],\n", - " \".-\",\n", - " color=\"blue\",\n", - " linewidth=0.5,\n", - " markersize=2,\n", - ")\n", - "ax5.set_xlim(results[\"timestamps\"][0], results[\"timestamps\"][-1])\n", - "ax5.set_ylabel(\"PPV\", fontsize=16)\n", - "ax5.axhline(y=np.mean(results[\"prec1\"]), color=\"dimgrey\", linestyle=\"--\")\n", - "ax5.set_xticklabels([])\n", - "ax5.pcolorfast(ax5.get_xlim(), ax5.get_ylim(), sig_drift, cmap=cmap, alpha=0.4)\n", - "\n", - "ax6.plot(\n", - " results[\"timestamps\"],\n", - " results[\"rec1\"],\n", - " \".-\",\n", - " color=\"blue\",\n", - " linewidth=0.5,\n", - " markersize=2,\n", - ")\n", - "ax6.set_xlim(results[\"timestamps\"][0], results[\"timestamps\"][-1])\n", - "ax6.set_ylabel(\"Sensitivity\", fontsize=16)\n", - "ax6.set_xlabel(\"time (s)\", fontsize=16)\n", - "ax6.axhline(y=np.mean(results[\"rec1\"]), color=\"dimgrey\", linestyle=\"--\")\n", - "ax6.tick_params(axis=\"x\", labelrotation=45)\n", - "ax6.pcolorfast(ax6.get_xlim(), ax6.get_ylim(), sig_drift, cmap=cmap, alpha=0.4)\n", - "\n", - "for index, label in enumerate(ax6.xaxis.get_ticklabels()):\n", - " if index % 28 != 0:\n", - " label.set_visible(False)\n", - "\n", - "plt.show()" - ] - }, - { - "cell_type": "markdown", - "id": "9d883afb-2a5a-4c09-8fb0-f80c264a7f49", - "metadata": {}, - "source": [ - "## Retraining: Drift Alarms " - ] - }, - { - "cell_type": "markdown", - "id": "c39a259d-0394-48f8-b52a-e58c25f05368", - "metadata": {}, - "source": [ - "### Drift Alarms" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "id": "2af5b979-7243-4675-97a7-dcc4b47ae7e8", - "metadata": {}, - 
"outputs": [], - "source": [ - "baseline = [127, 118, 119, 123, 127]\n", - "mostrecent30 = [132, 116, 97, 98, 128]\n", - "mostrecent60 = [100, 96, 108, 97, 97]\n", - "mostrecent120 = [96, 76, 101, 67, 89]\n", - "cumulative = [72, 112, 64, 85, 107]\n", - "\n", - "retraining_drift = pd.DataFrame(\n", - " {\n", - " \"Baseline\": baseline,\n", - " \"Most Recent \\n(30 days)\": mostrecent30,\n", - " \"Most Recent \\n(60 days)\": mostrecent60,\n", - " \"Most Recent \\n(120 days)\": mostrecent120,\n", - " \"Cumulative\": cumulative,\n", - " },\n", - ")\n", - "fig, ax = plt.subplots(figsize=(7, 4))\n", - "ax.boxplot(retraining_drift, patch_artist=True)\n", - "ax.set_xticks([1, 2, 3, 4, 5], retraining_drift.columns, rotation=45, fontsize=12)\n", - "ax.set_xlabel(\"Retraining Strategies\", fontsize=12)\n", - "ax.set_ylabel(\"Number of Drift Alarms\", fontsize=12)\n", - "plt.show()" - ] - }, - { - "cell_type": "markdown", - "id": "52d780a7-1c4d-4466-93b8-a72af9c3eb57", - "metadata": {}, - "source": [ - "### Number of Epochs" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "id": "414f713f-ed48-49e6-8e01-51fbb62888cd", - "metadata": {}, - "outputs": [], - "source": [ - "baseline = [127, 118, 119, 123, 127]\n", - "mostrecent120 = [96, 76, 101, 67, 89]\n", - "mostrecent120_10 = [97, 103, 98, 64, 94]\n", - "\n", - "retraining_drift = pd.DataFrame(\n", - " {\n", - " \"Baseline\": baseline,\n", - " \"Most Recent \\n(120 days, 1 epoch)\": mostrecent120,\n", - " \"Most Recent\\n (120 days, 10 epochs)\": mostrecent120_10,\n", - " },\n", - ")\n", - "fig, ax = plt.subplots(figsize=(7, 4))\n", - "ax.boxplot(retraining_drift, patch_artist=True)\n", - "ax.set_xticks([1, 2, 3], retraining_drift.columns, rotation=45, fontsize=12)\n", - "ax.set_xlabel(\"Retraining Strategies\", fontsize=12)\n", - "ax.set_ylabel(\"Number of Drift Alarms\", fontsize=12)\n", - "plt.show()" - ] - }, - { - "cell_type": "markdown", - "id": "0aceca93-ac9a-41a9-8c1a-a2702acbe587", - "metadata": {}, - 
"source": [ - "### Drift Threshold" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "id": "52a00d0e-3c6b-4527-a0cb-fd2cc31a02ab", - "metadata": {}, - "outputs": [], - "source": [ - "mostrecent120_10_2 = [50, 44, 40, 51, 61]\n", - "mostrecent120 = [96, 76, 101, 67, 89]\n", - "mostrecent120_10_1 = [121, 150, 123, 139, 131]\n", - "\n", - "retraining_drift = pd.DataFrame(\n", - " {\n", - " \"P-Val=0.01\": mostrecent120_10_2,\n", - " \"P-Val=0.05\": mostrecent120,\n", - " \"P-Val=0.1\": mostrecent120_10_1,\n", - " },\n", - ")\n", - "fig, ax = plt.subplots(figsize=(7, 4))\n", - "ax.boxplot(retraining_drift, patch_artist=True)\n", - "ax.set_xticks([1, 2, 3], retraining_drift.columns, rotation=45, fontsize=12)\n", - "ax.set_xlabel(\"Retraining Strategies\", fontsize=12)\n", - "ax.set_ylabel(\"Number of Drift Alarms\", fontsize=12)\n", - "plt.show()" - ] - }, - { - "cell_type": "markdown", - "id": "f972b02a-f2f4-4e84-8891-a1bc650d34b2", - "metadata": {}, - "source": [ - "## Retraining: PPV & Sensitivity" - ] - }, - { - "cell_type": "markdown", - "id": "44d7cbaf-d645-40b3-9e0f-a8a324298c23", - "metadata": {}, - "source": [ - "### Window Size" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "id": "d22682c3-5501-44cc-96bf-4099eb9622f1", - "metadata": {}, - "outputs": [], - "source": [ - "types = [\n", - " \"baseline\",\n", - " \"mostrecent30\",\n", - " \"mostrecent60\",\n", - " \"mostrecent120\",\n", - " \"cumulative_1epoch\",\n", - "]\n", - "labels = [\n", - " \"Baseline\",\n", - " \"Most Recent\\n(30days)\",\n", - " \"Most Recent\\n(60 days)\",\n", - " \"Most Recent\\n(120 days)\",\n", - " \"Cumulative\",\n", - "]\n", - "\n", - "drift_sensitivity = []\n", - "drift_ppv = []\n", - "for retraining_type in types:\n", - " for i in range(0, 5):\n", - " res_path = os.path.join(\n", - " PATH,\n", - " SHIFT,\n", - " SHIFT + \"_\" + retraining_type + \"_retraining_update.npy\",\n", - " )\n", - " cum = np.load(res_path, allow_pickle=True)[i]\n", - " 
drift_sensitivity.append(np.mean(cum[\"performance\"][\"rec1\"]))\n", - " drift_ppv.append(np.mean(cum[\"performance\"][\"prec1\"]))\n", - " # drift_sensitivity.append(np.mean(cum['performance']['rec1'][[i for i,v\n", - " # in enumerate(cum['pval']) if v < 0.05]]))\n", - " # drift_ppv.append(np.mean(cum['performance']['prec1'][[i for i,v\n", - " # in enumerate(cum['pval']) if v < 0.05]]))\n", - "\n", - "retraining_drift = pd.DataFrame(\n", - " {\n", - " \"Retraining Strategy\": np.repeat(types, 5),\n", - " \"PPV\": drift_ppv,\n", - " \"Sensitivity\": drift_sensitivity,\n", - " },\n", - ")\n", - "\n", - "fig, axs = plt.subplots(1, 2, figsize=(14, 4))\n", - "for j, variable in enumerate([\"PPV\", \"Sensitivity\"]):\n", - " for i, grp in enumerate(retraining_drift.groupby(\"Retraining Strategy\")):\n", - " axs[j].boxplot(\n", - " x=variable,\n", - " data=grp[1],\n", - " positions=[i],\n", - " widths=0.4,\n", - " patch_artist=True,\n", - " )\n", - " axs[j].set_xticks(range(0, len(types)), labels, rotation=45, fontsize=12)\n", - " axs[j].set_ylabel(variable, fontsize=12)\n", - "plt.show()" - ] - }, - { - "cell_type": "markdown", - "id": "68e011a5-92d0-4ebf-adca-83ae40c32f07", - "metadata": {}, - "source": [ - "### Number of Epochs" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "id": "395c27bf-4553-473c-8fe4-f95e99a03ad4", - "metadata": {}, - "outputs": [], - "source": [ - "types = [\"baseline\", \"mostrecent120\", \"mostrecent120_10epochs\"]\n", - "labels = [\"Baseline\", \"Most Recent\\n(120 days)\", \"Most Recent\\n(120 days, 10 epochs)\"]\n", - "\n", - "drift_sensitivity = []\n", - "drift_ppv = []\n", - "for retraining_type in types:\n", - " for i in range(0, 5):\n", - " res_path = os.path.join(\n", - " PATH,\n", - " SHIFT,\n", - " SHIFT + \"_\" + retraining_type + \"_retraining_update.npy\",\n", - " )\n", - " cum = np.load(res_path, allow_pickle=True)[i]\n", - " drift_sensitivity.append(np.mean(cum[\"performance\"][\"rec1\"]))\n", - " 
drift_ppv.append(np.mean(cum[\"performance\"][\"prec1\"]))\n", - " # drift_sensitivity.append(np.mean(cum['performance']['rec1'][[i for i,v\n", - " # in enumerate(cum['pval']) if v < 0.05]]))\n", - " # drift_ppv.append(np.mean(cum['performance']['prec1'][[i for i,v\n", - " # in enumerate(cum['pval']) if v < 0.05]]))\n", - "\n", - "retraining_drift = pd.DataFrame(\n", - " {\n", - " \"Retraining Strategy\": np.repeat(types, 5),\n", - " \"PPV\": drift_ppv,\n", - " \"Sensitivity\": drift_sensitivity,\n", - " },\n", - ")\n", - "\n", - "fig, axs = plt.subplots(1, 2, figsize=(14, 4))\n", - "for j, variable in enumerate([\"PPV\", \"Sensitivity\"]):\n", - " for i, grp in enumerate(retraining_drift.groupby(\"Retraining Strategy\")):\n", - " axs[j].boxplot(\n", - " x=variable,\n", - " data=grp[1],\n", - " positions=[i],\n", - " widths=0.4,\n", - " patch_artist=True,\n", - " )\n", - " axs[j].set_xticks(range(0, len(types)), labels, rotation=45, fontsize=12)\n", - " axs[j].set_ylabel(variable, fontsize=12)\n", - "plt.show()" - ] - }, - { - "cell_type": "markdown", - "id": "97902e7f-d9c4-4c05-81e2-57ba1dde34e0", - "metadata": {}, - "source": [ - "### Drift Threshold" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "id": "9c58a8e0-e101-4ecc-a1e8-e04fd366cd60", - "metadata": {}, - "outputs": [], - "source": [ - "types = [\n", - " \"mostrecent120_1epoch_pval0.01\",\n", - " \"mostrecent120\",\n", - " \"mostrecent120_1epoch_pval0.1\",\n", - "]\n", - "labels = [\"P-Val=0.01\", \"P-Val=0.05\", \"P-Val=0.1\"]\n", - "\n", - "drift_sensitivity = []\n", - "drift_ppv = []\n", - "for retraining_type in types:\n", - " for i in range(0, 5):\n", - " res_path = os.path.join(\n", - " PATH,\n", - " SHIFT,\n", - " SHIFT + \"_\" + retraining_type + \"_retraining_update.npy\",\n", - " )\n", - " cum = np.load(res_path, allow_pickle=True)[i]\n", - " drift_sensitivity.append(np.mean(cum[\"performance\"][\"rec1\"]))\n", - " drift_ppv.append(np.mean(cum[\"performance\"][\"prec1\"]))\n", - 
" # drift_sensitivity.append(np.mean(cum['performance']['rec1'][[i for i,v\n", - " # in enumerate(cum['pval']) if v < 0.05]]))\n", - " # drift_ppv.append(np.mean(cum['performance']['prec1'][[i for i,v\n", - " # in enumerate(cum['pval']) if v < 0.05]]))\n", - "\n", - "retraining_drift = pd.DataFrame(\n", - " {\n", - " \"Retraining Strategy\": np.repeat(types, 5),\n", - " \"PPV\": drift_ppv,\n", - " \"Sensitivity\": drift_sensitivity,\n", - " },\n", - ")\n", - "\n", - "fig, axs = plt.subplots(1, 2, figsize=(14, 4))\n", - "for j, variable in enumerate([\"PPV\", \"Sensitivity\"]):\n", - " for i, grp in enumerate(retraining_drift.groupby(\"Retraining Strategy\")):\n", - " axs[j].boxplot(\n", - " x=variable,\n", - " data=grp[1],\n", - " positions=[i],\n", - " widths=0.4,\n", - " patch_artist=True,\n", - " )\n", - " axs[j].set_xticks(range(0, len(types)), labels, rotation=45, fontsize=12)\n", - " axs[j].set_ylabel(variable, fontsize=12)\n", - "plt.show()" - ] - }, - { - "cell_type": "markdown", - "id": "d92f0a10-92fb-479d-a626-c5a28dbf88b8", - "metadata": {}, - "source": [ - "## Relationship between Performance and Drift P-Value" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "id": "47ea7822-b83a-41ee-9bc4-394eb13b296b", - "metadata": {}, - "outputs": [], - "source": [ - "# calculate Pearson's correlation\n", - "pcorr, pcorr_pval = pearsonr(results[\"prec1\"], results[\"pval\"])\n", - "print(\"Pearsons correlation: %.3f P-Value: %.3f\" % (pcorr, pcorr_pval))\n", - "# calculate spearman's correlation\n", - "scorr, scorr_pval = spearmanr(results[\"prec1\"], results[\"pval\"])\n", - "print(\"Spearmans correlation: %.3f P-Value: %.3f\" % (scorr, scorr_pval))\n", - "# plot\n", - "plt.scatter(results[\"prec1\"], results[\"pval\"])\n", - "plt.show()" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "id": "115e6398-411d-4c69-8770-9c76f30f3da5", - "metadata": {}, - "outputs": [], - "source": [ - "# calculate Pearson's correlation\n", - "pcorr, 
pcorr_pval = pearsonr(results[\"rec1\"], results[\"pval\"])\n", - "print(\"Pearsons correlation: %.3f P-Value: %.3f\" % (pcorr, pcorr_pval))\n", - "# calculate spearman's correlation\n", - "scorr, scorr_pval = spearmanr(results[\"rec1\"], results[\"pval\"])\n", - "print(\"Spearmans correlation: %.3f P-Value: %.3f\" % (scorr, scorr_pval))\n", - "# plot\n", - "plt.scatter(results[\"rec1\"], results[\"pval\"])\n", - "plt.show()" - ] - } - ], - "metadata": { - "kernelspec": { - "display_name": "Python 3.9.7 ('cyclops-4J2PL5I8-py3.9')", - "language": "python", - "name": "python3" - }, - "language_info": { - "codemirror_mode": { - "name": "ipython", - "version": 3 - }, - "file_extension": ".py", - "mimetype": "text/x-python", - "name": "python", - "nbconvert_exporter": "python", - "pygments_lexer": "ipython3", - "version": "3.9.7" - }, - "vscode": { - "interpreter": { - "hash": "bd2cd438e1c6ddffa3035fc73b17ac5cc0e0ea8897eb8be17cc645c6abf0c8cc" - } - } - }, - "nbformat": 4, - "nbformat_minor": 5 -} diff --git a/nbs/prefect_testing.ipynb b/nbs/prefect_testing.ipynb deleted file mode 100644 index 3d31728d5..000000000 --- a/nbs/prefect_testing.ipynb +++ /dev/null @@ -1,172 +0,0 @@ -{ - "cells": [ - { - "cell_type": "code", - "execution_count": null, - "id": "06d2d260", - "metadata": {}, - "outputs": [], - "source": [ - "import time\n", - "\n", - "import pandas as pd\n", - "\n", - "import cyclops.query.mimiciv as mimic\n", - "from cyclops.query import process as qp\n", - "from cyclops.query.mimiciv import SUBJECT_ID\n", - "from cyclops.workflow.task import join_queries_flow, normalize_events_flow" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "id": "4409f166", - "metadata": {}, - "outputs": [], - "source": [ - "events = mimic.events().run(limit=1000)" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "id": "84e4a470", - "metadata": {}, - "outputs": [], - "source": [ - "normalize_events_flow(events).result()" - ] - }, - { - "cell_type": 
"code", - "execution_count": null, - "id": "c1511ce7", - "metadata": {}, - "outputs": [], - "source": [ - "patients = mimic.patients()\n", - "patient_diagnoses = mimic.patient_diagnoses()" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "id": "621ff0a7", - "metadata": {}, - "outputs": [], - "source": [ - "patients_df = patients.run()" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "id": "a777e864", - "metadata": {}, - "outputs": [], - "source": [ - "patient_diagnoses_df = patient_diagnoses.run()" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "id": "c9421536", - "metadata": {}, - "outputs": [], - "source": [ - "t = time.time()\n", - "merged = pd.merge(patients_df, patient_diagnoses_df)\n", - "time.time() - t" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "id": "d9bd61ef", - "metadata": {}, - "outputs": [], - "source": [ - "# Run queries, join in pandas\n", - "1.500441 + 25.672523 + 3.3360862731933594" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "id": "aa3cbe6e", - "metadata": {}, - "outputs": [], - "source": [ - "# Perform join in SQLAlchemy\n", - "39.795473\n", - "\n", - "# Join in SQLAlchemy 12.622509\n", - "# Join in Pandas 3.336086\n", - "39.795473 - (1.500441 + 25.672523)" - ] - }, - { - "cell_type": "markdown", - "id": "6401d703", - "metadata": {}, - "source": [ - "Check out Pandas serializer: https://docs.prefect.io/api/latest/engine/serializers.html\n", - "\n", - "Override their write method for ours? Save directly instead of running/using their save\n", - "\n", - "This should save some time. Especially if we save to .csv. 
Also, if that's where they're sorting the DataFrame, then this could be a serious time save.\n", - "\n", - "Do this by defining a custom serialize function?\n", - "https://orion-docs.prefect.io/api-ref/prefect/flows/#prefect.flows.Flow.serialize_parameters\n", - "\n", - "https://github.com/PrefectHQ/prefect/blob/master/src/prefect/engine/serializers.py" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "id": "bb60010c", - "metadata": {}, - "outputs": [], - "source": [ - "query = qp.Join(patients.query, on=SUBJECT_ID)(patient_diagnoses.query)\n", - "mimic.get_interface(query).run()" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "id": "6aace91b", - "metadata": {}, - "outputs": [], - "source": [ - "t = time.time()\n", - "join_flow = join_queries_flow(patient_diagnoses, patients, on=[SUBJECT_ID])\n", - "print(time.time() - t)\n", - "join_flow.result()" - ] - } - ], - "metadata": { - "kernelspec": { - "display_name": "cyclops", - "language": "python", - "name": "cyclops" - }, - "language_info": { - "codemirror_mode": { - "name": "ipython", - "version": 3 - }, - "file_extension": ".py", - "mimetype": "text/x-python", - "name": "python", - "nbconvert_exporter": "python", - "pygments_lexer": "ipython3", - "version": "3.9.7" - } - }, - "nbformat": 4, - "nbformat_minor": 5 -} diff --git a/pyproject.toml b/pyproject.toml index df739057e..bd25a049a 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -188,9 +188,6 @@ extra_checks = true [tool.ruff] include = ["*.py", "pyproject.toml", "*.ipynb"] line-length = 88 -exclude = [ - "nbs", -] [tool.ruff.format] quote-style = "double"