From 13e3e73076c976a6306be41e865d7da484396f24 Mon Sep 17 00:00:00 2001
From: Leyla Kabuli <lakabuli@gmail.com>
Date: Mon, 18 Nov 2024 13:24:05 -0800
Subject: [PATCH] lensless mi experiments with updated api

---
 ...upervised_wiener_deconvolution_per_lens.py | 176 +++++
 ...s_deconvolution_plots_cifar10_figure.ipynb | 591 +++++++++++++++++
 ...n_cifar10_updated_api_reruns_smaller_lr.py | 154 +++++
 lensless_imager/lensless_helpers.py           | 612 ++++++++++++++++++
 4 files changed, 1533 insertions(+)
 create mode 100644 lensless_imager/2024_10_22_sweep_unsupervised_wiener_deconvolution_per_lens.py
 create mode 100644 lensless_imager/2024_10_23_mi_vs_deconvolution_plots_cifar10_figure.ipynb
 create mode 100644 lensless_imager/2024_10_23_pixelcnn_cifar10_updated_api_reruns_smaller_lr.py
 create mode 100644 lensless_imager/lensless_helpers.py

diff --git a/lensless_imager/2024_10_22_sweep_unsupervised_wiener_deconvolution_per_lens.py b/lensless_imager/2024_10_22_sweep_unsupervised_wiener_deconvolution_per_lens.py
new file mode 100644
index 0000000..15fc594
--- /dev/null
+++ b/lensless_imager/2024_10_22_sweep_unsupervised_wiener_deconvolution_per_lens.py
@@ -0,0 +1,176 @@
+# ---
+# jupyter:
+#   jupytext:
+#     text_representation:
+#       extension: .py
+#       format_name: percent
+#       format_version: '1.3'
+#       jupytext_version: 1.15.2
+#   kernelspec:
+#     display_name: infotransformer
+#     language: python
+#     name: python3
+# ---
+
+# %% [markdown]
+# ## Sweeping both unsupervised Wiener deconvolution and regular Wiener deconvolution with a hand-tuned regularization parameter
+#
+# Using a fixed seed (10) for consistency.
+
+# %%
+# %load_ext autoreload
+# %autoreload 2
+
+import os
+from jax import config
+config.update("jax_enable_x64", True)
+import sys 
+sys.path.append('/home/lakabuli/workspace/EncodingInformation/src')
+
+os.environ["CUDA_DEVICE_ORDER"] = "PCI_BUS_ID" 
+os.environ["CUDA_VISIBLE_DEVICES"] = '1'
+from encoding_information.gpu_utils import limit_gpu_memory_growth
+limit_gpu_memory_growth()
+
+
+from cleanplots import *
+import numpy as np
+import tensorflow as tf
+import tensorflow.keras as tfk
+
+from lensless_helpers import *
+from tqdm import tqdm
+
+# %%
+from encoding_information.image_utils import add_noise
+from skimage.restoration import wiener, unsupervised_wiener, richardson_lucy
+import skimage.metrics as skm
+
+# %%
+# load the PSFs
+
+diffuser_psf = load_diffuser_32()
+one_psf = load_single_lens_uniform(32)
+two_psf = load_two_lens_uniform(32)
+three_psf = load_three_lens_uniform(32)
+four_psf = load_four_lens_uniform(32)
+five_psf = load_five_lens_uniform(32)
+aperture_psf = np.copy(diffuser_psf)
+aperture_psf[:5] = 0
+aperture_psf[-5:] = 0
+aperture_psf[:,:5] = 0
+aperture_psf[:,-5:] = 0
+
+
+# %%
+def compute_skm_metrics(gt, recon):
+    # takes in already normalized gt
+    mse = skm.mean_squared_error(gt, recon)
+    psnr = skm.peak_signal_noise_ratio(gt, recon)
+    nmse = skm.normalized_root_mse(gt, recon)
+    ssim = skm.structural_similarity(gt, recon, data_range=1)
+    return mse, psnr, nmse, ssim
+
+
+# %%
+# set seed values for reproducibility
+seed_values_full = np.arange(1, 4)
+
+# set photon properties 
+#mean_photon_count_list = [20, 40, 60, 80, 100, 150, 200, 250, 300]
+mean_photon_count_list = [20, 40, 80, 160, 320]
+
+# set eligible psfs
+
+psf_patterns_use = [one_psf, four_psf, diffuser_psf]
+psf_names_use = ['one', 'four', 'diffuser']
+
+save_dir = '/home/lakabuli/workspace/EncodingInformation/lensless_imager/deconvolutions/'
+
+
+# MI estimator parameters 
+patch_size = 32
+num_patches = 10000
+test_set_size = 1500 
+bs = 500
+max_epochs = 50
+
+seed_value = 10
+
+reg_value_best = 10**-2
+
+# %%
+# data generation process 
+
+for photon_count in mean_photon_count_list:
+    for psf_idx, psf_pattern in enumerate(psf_patterns_use):
+        # load dataset 
+        (x_train, y_train), (x_test, y_test) = tfk.datasets.cifar10.load_data()
+        data = np.concatenate((x_train, x_test), axis=0)
+        data = data.astype(np.float64)
+        labels = np.concatenate((y_train, y_test), axis=0) # combine train and test labels into one array
+        # convert data to grayscale before converting to photons
+        if len(data.shape) == 4:
+            data = tf.image.rgb_to_grayscale(data).numpy()
+            data = data.squeeze()
+        # convert to photons with mean value of photon_count
+        data /= np.mean(data)
+        data *= photon_count
+        # get maximum value in this data 
+        max_val = np.max(data)
+        # make tiled data 
+        random_data, random_labels = generate_random_tiled_data(data, labels, seed_value)
+        # zero out everything except the middle 32x32 tile of the tiled data
+        data_padded = np.zeros((data.shape[0], 96, 96))
+        data_padded[:, 32:64, 32:64] = random_data[:, 32:64, 32:64]
+        # save the middle tile as the ground truth for metric computation, keeping only the test-set portion
+        gt_data = data_padded[:, 32:64, 32:64]
+        gt_data = gt_data[-test_set_size:]
+        # extract the test set before doing convolution 
+        test_data = data_padded[-test_set_size:]
+        # convolve the data 
+        convolved_data = convolved_dataset(psf_pattern, test_data) 
+        convolved_data_noisy = add_noise(convolved_data, seed=seed_value)
+        # output of add_noise is a jax array that's float32, convert to regular numpy array and float64.
+        convolved_data_noisy = np.array(convolved_data_noisy).astype(np.float64)
+
+        # compute metrics using unsupervised wiener deconvolution 
+        mse_psf = []
+        psnr_psf = [] 
+        ssim_psf = []
+        for i in tqdm(range(convolved_data_noisy.shape[0])):
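+            # scale the noisy measurement and ground truth by the dataset max so metrics are computed on a common ~[0, 1] range (matching data_range=1 for SSIM)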
+            recon, _ = unsupervised_wiener(convolved_data_noisy[i] / max_val, psf_pattern)
+            recon = recon[17:49, 17:49] # 32x32 crop window of the reconstruction used for metric comparison
+            mse = skm.mean_squared_error(gt_data[i] / max_val, recon)
+            psnr = skm.peak_signal_noise_ratio(gt_data[i] / max_val, recon)
+            ssim = skm.structural_similarity(gt_data[i] / max_val, recon, data_range=1)
+            mse_psf.append(mse)
+            psnr_psf.append(psnr)
+            ssim_psf.append(ssim)
+            
+        print('PSF: {}, Mean MSE: {}, Mean PSNR: {}, Mean SSIM: {}'.format(psf_names_use[psf_idx], np.mean(mse_psf), np.mean(psnr_psf), np.mean(ssim_psf)))
+        np.save(save_dir + 'unsupervised_wiener_recon_{}_photon_count_{}_psf.npy'.format(photon_count, psf_names_use[psf_idx]), [mse_psf, psnr_psf, ssim_psf])
+
+        # repeat with regular Wiener deconvolution using the hand-tuned regularization value
+        mse_psf = []
+        psnr_psf = []
+        ssim_psf = []
+        for i in tqdm(range(convolved_data_noisy.shape[0])): 
+            recon = wiener(convolved_data_noisy[i] / max_val, psf_pattern, reg_value_best)
+            recon = recon[17:49, 17:49] # 32x32 crop window of the reconstruction used for metric comparison
+            mse = skm.mean_squared_error(gt_data[i] / max_val, recon)
+            psnr = skm.peak_signal_noise_ratio(gt_data[i] / max_val, recon)
+            ssim = skm.structural_similarity(gt_data[i] / max_val, recon, data_range=1)
+            mse_psf.append(mse)
+            psnr_psf.append(psnr)
+            ssim_psf.append(ssim)
+        print('PSF: {}, Mean MSE: {}, Mean PSNR: {}, Mean SSIM: {}'.format(psf_names_use[psf_idx], np.mean(mse_psf), np.mean(psnr_psf), np.mean(ssim_psf)))
+        np.save(save_dir + 'regular_wiener_recon_{}_photon_count_{}_psf.npy'.format(photon_count, psf_names_use[psf_idx]), [mse_psf, psnr_psf, ssim_psf])
+
+
+
+
+
+# %%
+
+
diff --git a/lensless_imager/2024_10_23_mi_vs_deconvolution_plots_cifar10_figure.ipynb b/lensless_imager/2024_10_23_mi_vs_deconvolution_plots_cifar10_figure.ipynb
new file mode 100644
index 0000000..97c9575
--- /dev/null
+++ b/lensless_imager/2024_10_23_mi_vs_deconvolution_plots_cifar10_figure.ipynb
@@ -0,0 +1,591 @@
+{
+ "cells": [
+  {
+   "cell_type": "markdown",
+   "metadata": {},
+   "source": [
+    "# Make the plot for MI and deconvolution relationship for paper figure, 2024/10/23"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": null,
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "%load_ext autoreload \n",
+    "%autoreload 2"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": null,
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "from jax import config\n",
+    "config.update(\"jax_enable_x64\", True)\n",
+    "import numpy as np\n",
+    "\n",
+    "import sys \n",
+    "sys.path.append('/home/lakabuli/workspace/EncodingInformation/src')\n",
+    "from lensless_helpers import *\n",
+    "import os\n",
+    "os.environ[\"CUDA_DEVICE_ORDER\"] = \"PCI_BUS_ID\" \n",
+    "os.environ[\"CUDA_VISIBLE_DEVICES\"] = '0'\n",
+    "print(os.environ.get('PYTHONPATH'))\n",
+    "from cleanplots import * "
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": null,
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "seed_value = 10\n",
+    "\n",
+    "# set photon properties \n",
+    "bias = 10 # in photons\n",
+    "mean_photon_count_list = [20, 40, 80, 160, 320]\n",
+    "max_photon_count = mean_photon_count_list[-1]\n",
+    "\n",
+    "# set eligible psfs\n",
+    "\n",
+    "psf_names = ['one', 'four', 'diffuser']\n",
+    "\n",
+    "# MI estimator parameters \n",
+    "patch_size = 32\n",
+    "num_patches = 10000\n",
+    "val_set_size = 1000\n",
+    "test_set_size = 1500\n",
+    "\n",
+    "mi_dir = '/home/lakabuli/workspace/EncodingInformation/lensless_imager/mi_estimates_smaller_lr/'\n",
+    "recon_dir = '/home/lakabuli/workspace/EncodingInformation/lensless_imager/deconvolutions/'"
+   ]
+  },
+  {
+   "cell_type": "markdown",
+   "metadata": {},
+   "source": [
+    "## Load MI data and make plots of it\n",
+    "\n",
+    "The plot has essentially invisible error bars. No outlier issues"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": null,
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "from cleanplots import *\n",
+    "get_color_cycle()[0]"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": null,
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "fig, ax = plt.subplots(1, 1, figsize=(8, 6))\n",
+    "mis_across_psfs = []\n",
+    "lowers_across_psfs = []\n",
+    "uppers_across_psfs = []\n",
+    "for psf_name in psf_names:\n",
+    "    mis = []\n",
+    "    lowers = []\n",
+    "    uppers = []\n",
+    "    for photon_count in mean_photon_count_list:\n",
+    "        mi_estimates = np.load(mi_dir + 'pixelcnn_mi_estimate_{}_photon_count_{}_psf.npy'.format(photon_count, psf_name))\n",
+    "        mi_values = mi_estimates[0]\n",
+    "        print(np.max(mi_values) - np.min(mi_values))\n",
+    "        lower_bounds = mi_estimates[1]\n",
+    "        upper_bounds = mi_estimates[2]\n",
+    "        # get index that has smallest mi value across the different model runs.\n",
+    "        min_mi_index = np.argmin(mi_values)\n",
+    "        mis.append(mi_values[min_mi_index])\n",
+    "        lowers.append(lower_bounds[min_mi_index])\n",
+    "        uppers.append(upper_bounds[min_mi_index])\n",
+    "    ax.plot(mean_photon_count_list, mis, label=psf_name) \n",
+    "    ax.fill_between(mean_photon_count_list, lowers, uppers, alpha=0.3)\n",
+    "    mis_across_psfs.append(mis)\n",
+    "    lowers_across_psfs.append(lowers)\n",
+    "    uppers_across_psfs.append(uppers)\n",
+    "plt.legend()\n",
+    "plt.title(\"PixelCNN MI estimates across Photon Count, CIFAR10\")\n",
+    "plt.xlabel(\"Mean Photon Count\")\n",
+    "plt.ylabel(\"Estimated Mutual Information\")\n",
+    "mis_across_psfs = np.array(mis_across_psfs)\n",
+    "lowers_across_psfs = np.array(lowers_across_psfs)\n",
+    "uppers_across_psfs = np.array(uppers_across_psfs)"
+   ]
+  },
+  {
+   "cell_type": "markdown",
+   "metadata": {},
+   "source": [
+    "## Load recon data and make plots of it"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": null,
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "mses_across_psfs = []\n",
+    "mse_lowers_across_psfs = []\n",
+    "mse_uppers_across_psfs = []\n",
+    "psnrs_across_psfs = []\n",
+    "psnr_lowers_across_psfs = []\n",
+    "psnr_uppers_across_psfs = []\n",
+    "ssims_across_psfs = []\n",
+    "ssim_lowers_across_psfs = []\n",
+    "ssim_uppers_across_psfs = []\n",
+    "\n",
+    "for psf_name in psf_names: \n",
+    "    mse_vals = []\n",
+    "    mse_lowers = []\n",
+    "    mse_uppers = []\n",
+    "    psnr_vals = []\n",
+    "    psnr_lowers = []\n",
+    "    psnr_uppers = []\n",
+    "    ssim_vals = []\n",
+    "    ssim_lowers = []\n",
+    "    ssim_uppers = []\n",
+    "    for photon_count in mean_photon_count_list:\n",
+    "        metrics = np.load(recon_dir + 'unsupervised_wiener_recon_{}_photon_count_{}_psf.npy'.format(photon_count, psf_name))\n",
+    "        mse = metrics[0]\n",
+    "        psnr = metrics[1] \n",
+    "        ssim = metrics[2]\n",
+    "        bootstrap_mse, bootstrap_psnr, bootstrap_ssim = compute_bootstraps(mse, psnr, ssim, test_set_size)\n",
+    "        mean_mse, lower_bound_mse, upper_bound_mse = compute_confidence_interval(bootstrap_mse, confidence_interval=0.95)\n",
+    "        mean_psnr, lower_bound_psnr, upper_bound_psnr = compute_confidence_interval(bootstrap_psnr, confidence_interval=0.95)\n",
+    "        mean_ssim, lower_bound_ssim, upper_bound_ssim = compute_confidence_interval(bootstrap_ssim, confidence_interval=0.95)\n",
+    "        mse_vals.append(mean_mse)\n",
+    "        mse_lowers.append(lower_bound_mse)\n",
+    "        mse_uppers.append(upper_bound_mse)\n",
+    "        psnr_vals.append(mean_psnr)\n",
+    "        psnr_lowers.append(lower_bound_psnr)\n",
+    "        psnr_uppers.append(upper_bound_psnr)\n",
+    "        ssim_vals.append(mean_ssim)\n",
+    "        ssim_lowers.append(lower_bound_ssim)\n",
+    "        ssim_uppers.append(upper_bound_ssim)\n",
+    "    mses_across_psfs.append(mse_vals)\n",
+    "    mse_lowers_across_psfs.append(mse_lowers)\n",
+    "    mse_uppers_across_psfs.append(mse_uppers)\n",
+    "    psnrs_across_psfs.append(psnr_vals)\n",
+    "    psnr_lowers_across_psfs.append(psnr_lowers)\n",
+    "    psnr_uppers_across_psfs.append(psnr_uppers)\n",
+    "    ssims_across_psfs.append(ssim_vals)\n",
+    "    ssim_lowers_across_psfs.append(ssim_lowers)\n",
+    "    ssim_uppers_across_psfs.append(ssim_uppers)\n",
+    "mses_across_psfs = np.array(mses_across_psfs)\n",
+    "mse_lowers_across_psfs = np.array(mse_lowers_across_psfs)\n",
+    "mse_uppers_across_psfs = np.array(mse_uppers_across_psfs)\n",
+    "psnrs_across_psfs = np.array(psnrs_across_psfs)\n",
+    "psnr_lowers_across_psfs = np.array(psnr_lowers_across_psfs)\n",
+    "psnr_uppers_across_psfs = np.array(psnr_uppers_across_psfs)\n",
+    "ssims_across_psfs = np.array(ssims_across_psfs)\n",
+    "ssim_lowers_across_psfs = np.array(ssim_lowers_across_psfs)\n",
+    "ssim_uppers_across_psfs = np.array(ssim_uppers_across_psfs)\n",
+    "plt.figure(figsize=(20, 5))\n",
+    "plt.subplot(1, 3, 1)\n",
+    "for i in range(len(psf_names)):\n",
+    "    plt.plot(mean_photon_count_list, mses_across_psfs[i], label=psf_names[i])\n",
+    "    plt.fill_between(mean_photon_count_list, mse_lowers_across_psfs[i], mse_uppers_across_psfs[i], alpha=0.5)\n",
+    "plt.title(\"MSE\")\n",
+    "plt.legend()\n",
+    "plt.subplot(1, 3, 2)\n",
+    "for i in range(len(psf_names)):\n",
+    "    plt.plot(mean_photon_count_list, psnrs_across_psfs[i], label=psf_names[i])\n",
+    "    plt.fill_between(mean_photon_count_list, psnr_lowers_across_psfs[i], psnr_uppers_across_psfs[i], alpha=0.5)\n",
+    "plt.title(\"PSNR\")\n",
+    "plt.subplot(1, 3, 3)\n",
+    "for i in range(len(psf_names)):\n",
+    "    plt.plot(mean_photon_count_list, ssims_across_psfs[i], label=psf_names[i])\n",
+    "    plt.fill_between(mean_photon_count_list, ssim_lowers_across_psfs[i], ssim_uppers_across_psfs[i], alpha=0.5)\n",
+    "plt.title(\"SSIM\")\n",
+    "plt.legend()"
+   ]
+  },
+  {
+   "cell_type": "markdown",
+   "metadata": {},
+   "source": [
+    "## Make figures, omitting error bars since smaller than marker size and reverting to circular markers"
+   ]
+  },
+  {
+   "cell_type": "markdown",
+   "metadata": {},
+   "source": [
+    "### Setup"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": null,
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "def marker_for_psf(psf_name):\n",
+    "    if psf_name =='one':\n",
+    "        marker = 'o'\n",
+    "    elif psf_name == 'four':\n",
+    "        marker = 'o'\n",
+    "        #marker = 's' \n",
+    "    elif psf_name == 'diffuser':\n",
+    "        #marker = '*'\n",
+    "        marker = 'o'\n",
+    "    elif psf_name == 'uc':\n",
+    "        marker = 'x'\n",
+    "    elif psf_name =='two':\n",
+    "        marker = 'd'\n",
+    "    elif psf_name == 'three':\n",
+    "        marker = 'v'\n",
+    "    elif psf_name == 'five':\n",
+    "        marker = 'p'\n",
+    "    elif psf_name == 'aperture':\n",
+    "        marker = 'P'\n",
+    "    return marker"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": null,
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "# Choose a base colormap\n",
+    "base_colormap = plt.get_cmap('inferno')\n",
+    "# Define the start and end points--used so that high values aren't too light against white background\n",
+    "start, end = 0, 0.88 # making end point 0.8\n",
+    "from matplotlib.colors import LinearSegmentedColormap\n",
+    "# Create a new colormap from the portion of the original colormap\n",
+    "colormap = LinearSegmentedColormap.from_list(\n",
+    "    'trunc({n},{a:.2f},{b:.2f})'.format(n=base_colormap.name, a=start, b=end),\n",
+    "    base_colormap(np.linspace(start, end, 256))\n",
+    ")\n",
+    "\n",
+    "min_photons_per_pixel =  min(mean_photon_count_list)\n",
+    "max_photons_per_pixel =  max(mean_photon_count_list)\n",
+    "\n",
+    "min_log_photons = np.log(min_photons_per_pixel)\n",
+    "max_log_photons = np.log(max_photons_per_pixel)\n",
+    "\n",
+    "def color_for_photon_level(photons_per_pixel):\n",
+    "    log_photons = np.log(photons_per_pixel)\n",
+    "    return colormap((log_photons - min_log_photons) / (max_log_photons - min_log_photons) )"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": null,
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "# old format for selecting target indices, now not used much\n",
+    "metric_type = 1 # 0 for MSE, 1 for PSNR \n",
+    "valid_psfs = [0, 1, 2]\n",
+    "valid_photon_counts = [20, 40, 80, 160, 320]\n",
+    "psf_names = [psf_names[i] for i in valid_psfs]\n",
+    "print(psf_names)"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": null,
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "mse_error_lower = np.abs(mses_across_psfs - mse_lowers_across_psfs)\n",
+    "mse_error_upper = np.abs(mse_uppers_across_psfs - mses_across_psfs)\n",
+    "psnr_error_lower = np.abs(psnrs_across_psfs - psnr_lowers_across_psfs)\n",
+    "psnr_error_upper = np.abs(psnr_uppers_across_psfs - psnrs_across_psfs)\n",
+    "mi_error_lower = np.abs(mis_across_psfs - lowers_across_psfs)\n",
+    "mi_error_upper = np.abs(uppers_across_psfs - mis_across_psfs)"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": null,
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "fig, ax = plt.subplots(1, 1, figsize=(6, 5))\n",
+    "for psf_idx, psf_name in enumerate(psf_names):\n",
+    "    # plot all of the points here. \n",
+    "    mi_means_across_photons = []\n",
+    "    recon_means_across_photons = []\n",
+    "    for photon_idx, photon_count in enumerate(mean_photon_count_list):\n",
+    "        color = color_for_photon_level(photon_count) \n",
+    "        mi_value = mis_across_psfs[psf_idx][photon_idx] \n",
+    "        recon_value = mses_across_psfs[psf_idx][photon_idx] \n",
+    "        ax.scatter(mi_value, recon_value, color=color, marker=marker_for_psf(psf_name), s=50, zorder=100)\n",
+    "        # add to lists to track later \n",
+    "        mi_means_across_photons.append(mi_value)\n",
+    "        recon_means_across_photons.append(recon_value)\n",
+    "    #ax.errorbar(mis_across_psfs[psf_idx], mses_across_psfs[psf_idx], xerr=[mi_error_lower[psf_idx], mi_error_upper[psf_idx]], yerr=[mse_error_lower[psf_idx], mse_error_upper[psf_idx]], fmt='o', capsize=5, ecolor='black', markersize=8, barsabove=True)\n",
+    "    \n",
+    "    mi_means_across_photons = np.array(mi_means_across_photons)\n",
+    "    recon_means_across_photons = np.array(recon_means_across_photons)\n",
+    "    ax.plot(mi_means_across_photons, recon_means_across_photons, '--', color='gray', alpha=1, linewidth=2)\n",
+    "ax.set_xlabel(\"Mutual Information (bits per pixel)\")\n",
+    "ax.set_ylabel(\"Mean Squared Error\")\n",
+    "clear_spines(ax)\n",
+    "\n",
+    "\n",
+    "# legend\n",
+    "# ax.scatter([], [], color='k', marker='o', label='One Lens')\n",
+    "# ax.scatter([], [], color='k', marker='s', label='Four Lens')\n",
+    "# ax.scatter([], [], color='k', marker='*', label='Diffuser')\n",
+    "\n",
+    "ax.legend(loc='upper right', frameon=True)\n",
+    "ax.set_xlim([0, None])\n",
+    "\n",
+    "\n",
+    "\n",
+    "norm = mpl.colors.Normalize(vmin=min_log_photons, vmax=max_log_photons)\n",
+    "sm = plt.cm.ScalarMappable(cmap=colormap, norm=norm)\n",
+    "sm.set_array([])\n",
+    "cbar = plt.colorbar(sm, ax=ax, ticks=(np.log(valid_photon_counts)))\n",
+    "# set tick labels\n",
+    "cbar.ax.set_yticklabels(valid_photon_counts)\n",
+    "\n",
+    "\n",
+    "cbar.set_label('Photons per pixel')\n",
+    "\n",
+    "#plt.savefig('mse_vs_MI_with_confidence_intervals_log_photons.pdf', bbox_inches='tight', transparent=True)"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": null,
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "fig, ax = plt.subplots(1, 1, figsize=(6, 5))\n",
+    "for psf_idx, psf_name in enumerate(psf_names):\n",
+    "    # plot all of the points here. \n",
+    "    mi_means_across_photons = []\n",
+    "    recon_means_across_photons = []\n",
+    "    for photon_idx, photon_count in enumerate(mean_photon_count_list):\n",
+    "        color = color_for_photon_level(photon_count) \n",
+    "        mi_value = mis_across_psfs[psf_idx][photon_idx] \n",
+    "        recon_value = psnrs_across_psfs[psf_idx][photon_idx] \n",
+    "        ax.scatter(mi_value, recon_value, color=color, marker=marker_for_psf(psf_name), s=50, zorder=100)\n",
+    "        # add to lists to track later \n",
+    "        mi_means_across_photons.append(mi_value)\n",
+    "        recon_means_across_photons.append(recon_value)\n",
+    "    #ax.errorbar(mis_across_psfs[psf_idx], psnrs_across_psfs[psf_idx], xerr=[mi_error_lower[psf_idx], mi_error_upper[psf_idx]], yerr=[psnr_error_lower[psf_idx], psnr_error_upper[psf_idx]], fmt='o', capsize=5, ecolor='black', markersize=8, barsabove=True)\n",
+    "    \n",
+    "    mi_means_across_photons = np.array(mi_means_across_photons)\n",
+    "    recon_means_across_photons = np.array(recon_means_across_photons)\n",
+    "    ax.plot(mi_means_across_photons, recon_means_across_photons, '--', color='gray', alpha=1, linewidth=2)\n",
+    "ax.set_xlabel(\"Mutual Information (bits per pixel)\")\n",
+    "ax.set_ylabel(\"Peak Signal-to-Noise Ratio (dB)\")\n",
+    "clear_spines(ax)\n",
+    "\n",
+    "\n",
+    "# legend\n",
+    "# ax.scatter([], [], color='k', marker='o', label='One Lens')\n",
+    "# ax.scatter([], [], color='k', marker='s', label='Four Lens')\n",
+    "# ax.scatter([], [], color='k', marker='*', label='Diffuser')\n",
+    "\n",
+    "ax.legend(loc='lower right', frameon=True)\n",
+    "ax.set_xlim([0, None])\n",
+    "\n",
+    "\n",
+    "\n",
+    "norm = mpl.colors.Normalize(vmin=min_log_photons, vmax=max_log_photons)\n",
+    "sm = plt.cm.ScalarMappable(cmap=colormap, norm=norm)\n",
+    "sm.set_array([])\n",
+    "cbar = plt.colorbar(sm, ax=ax, ticks=(np.log(valid_photon_counts)))\n",
+    "# set tick labels\n",
+    "cbar.ax.set_yticklabels(valid_photon_counts)\n",
+    "\n",
+    "\n",
+    "cbar.set_label('Photons per pixel')\n",
+    "\n",
+    "#plt.savefig('psnr_vs_MI_with_confidence_intervals_log_photons.pdf', bbox_inches='tight', transparent=True)\n",
+    "    "
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": null,
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "fig, ax = plt.subplots(1, 1, figsize=(6, 5))\n",
+    "for psf_idx, psf_name in enumerate(psf_names):\n",
+    "    # plot all of the points here. \n",
+    "    mi_means_across_photons = []\n",
+    "    recon_means_across_photons = []\n",
+    "    for photon_idx, photon_count in enumerate(mean_photon_count_list):\n",
+    "        color = color_for_photon_level(photon_count) \n",
+    "        mi_value = mis_across_psfs[psf_idx][photon_idx] \n",
+    "        recon_value = ssims_across_psfs[psf_idx][photon_idx] \n",
+    "        ax.scatter(mi_value, recon_value, color=color, marker=marker_for_psf(psf_name), s=50, zorder=100)\n",
+    "        # add to lists to track later \n",
+    "        mi_means_across_photons.append(mi_value)\n",
+    "        recon_means_across_photons.append(recon_value)\n",
+    "    #ax.errorbar(mis_across_psfs[psf_idx], ssims_across_psfs[psf_idx], xerr=[mi_error_lower[psf_idx], mi_error_upper[psf_idx]], yerr=[ssim_error_lower[psf_idx], ssim_error_upper[psf_idx]], fmt='o', capsize=5, ecolor='black', markersize=8, barsabove=True)\n",
+    "    \n",
+    "    mi_means_across_photons = np.array(mi_means_across_photons)\n",
+    "    recon_means_across_photons = np.array(recon_means_across_photons)\n",
+    "    ax.plot(mi_means_across_photons, recon_means_across_photons, '--', color='gray', alpha=1, linewidth=2)\n",
+    "ax.set_xlabel(\"Mutual Information (bits per pixel)\")\n",
+    "ax.set_ylabel(\"Structural Similarity Index Measure (SSIM)\")\n",
+    "clear_spines(ax)\n",
+    "\n",
+    "\n",
+    "# legend\n",
+    "# ax.scatter([], [], color='k', marker='o', label='One Lens')\n",
+    "# ax.scatter([], [], color='k', marker='s', label='Four Lens')\n",
+    "# ax.scatter([], [], color='k', marker='*', label='Diffuser')\n",
+    "\n",
+    "ax.legend(loc='lower right', frameon=True)\n",
+    "ax.set_xlim([0, None])\n",
+    "\n",
+    "\n",
+    "\n",
+    "norm = mpl.colors.Normalize(vmin=min_log_photons, vmax=max_log_photons)\n",
+    "sm = plt.cm.ScalarMappable(cmap=colormap, norm=norm)\n",
+    "sm.set_array([])\n",
+    "cbar = plt.colorbar(sm, ax=ax, ticks=(np.log(valid_photon_counts)))\n",
+    "# set tick labels\n",
+    "cbar.ax.set_yticklabels(valid_photon_counts)\n",
+    "\n",
+    "\n",
+    "cbar.set_label('Photons per pixel')\n",
+    "\n",
+    "#plt.savefig('ssim_vs_MI_with_confidence_intervals_log_photons.pdf', bbox_inches='tight', transparent=True)\n",
+    "    "
+   ]
+  },
+  {
+   "cell_type": "markdown",
+   "metadata": {},
+   "source": [
+    "Put all 3 into one figure"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": null,
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "import glob\n",
+    "import numpy as np\n",
+    "import matplotlib.pyplot as plt\n",
+    "from cleanplots import *\n",
+    "from matplotlib.ticker import ScalarFormatter\n",
+    "\n",
+    "figs, axs = plt.subplots(1, 3, figsize=(12, 4), sharex=True)\n",
+    "\n",
+    "\n",
+    "for psf_idx, psf_name in enumerate(psf_names):\n",
+    "    # plot all of the points here. \n",
+    "    mi_means_across_photons = []\n",
+    "    recon_means_across_photons = []\n",
+    "    for photon_idx, photon_count in enumerate(mean_photon_count_list):\n",
+    "        color = color_for_photon_level(photon_count) \n",
+    "        mi_value = mis_across_psfs[psf_idx][photon_idx] \n",
+    "        recon_value = mses_across_psfs[psf_idx][photon_idx] \n",
+    "        axs[0].scatter(mi_value, recon_value, color=color, marker=marker_for_psf(psf_name), s=50, zorder=100)\n",
+    "        # add to lists to track later \n",
+    "        mi_means_across_photons.append(mi_value)\n",
+    "        recon_means_across_photons.append(recon_value)\n",
+    "    #ax.errorbar(mis_across_psfs[psf_idx], mses_across_psfs[psf_idx], xerr=[mi_error_lower[psf_idx], mi_error_upper[psf_idx]], yerr=[mse_error_lower[psf_idx], mse_error_upper[psf_idx]], fmt='o', capsize=5, ecolor='black', markersize=8, barsabove=True)\n",
+    "    \n",
+    "    mi_means_across_photons = np.array(mi_means_across_photons)\n",
+    "    recon_means_across_photons = np.array(recon_means_across_photons)\n",
+    "    axs[0].plot(mi_means_across_photons, recon_means_across_photons, '--', color='gray', alpha=1, linewidth=2)\n",
+    "#axs[0].set_xlabel(\"Mutual Information (bits per pixel)\")\n",
+    "axs[0].set_title(\"Mean Squared Error\")\n",
+    "clear_spines(axs[0])\n",
+    "\n",
+    "for psf_idx, psf_name in enumerate(psf_names):\n",
+    "    # plot all of the points here. \n",
+    "    mi_means_across_photons = []\n",
+    "    recon_means_across_photons = []\n",
+    "    for photon_idx, photon_count in enumerate(mean_photon_count_list):\n",
+    "        color = color_for_photon_level(photon_count) \n",
+    "        mi_value = mis_across_psfs[psf_idx][photon_idx] \n",
+    "        recon_value = ssims_across_psfs[psf_idx][photon_idx] \n",
+    "        axs[1].scatter(mi_value, recon_value, color=color, marker=marker_for_psf(psf_name), s=50, zorder=100)\n",
+    "        # add to lists to track later \n",
+    "        mi_means_across_photons.append(mi_value)\n",
+    "        recon_means_across_photons.append(recon_value)\n",
+    "    #ax.errorbar(mis_across_psfs[psf_idx], ssims_across_psfs[psf_idx], xerr=[mi_error_lower[psf_idx], mi_error_upper[psf_idx]], yerr=[ssim_error_lower[psf_idx], ssim_error_upper[psf_idx]], fmt='o', capsize=5, ecolor='black', markersize=8, barsabove=True)\n",
+    "    \n",
+    "    mi_means_across_photons = np.array(mi_means_across_photons)\n",
+    "    recon_means_across_photons = np.array(recon_means_across_photons)\n",
+    "    axs[1].plot(mi_means_across_photons, recon_means_across_photons, '--', color='gray', alpha=1, linewidth=2)\n",
+    "axs[1].set_xlabel(\"Mutual Information (bits per pixel)\")\n",
+    "axs[1].set_title(\"Structural Similarity Index Measure (SSIM)\")\n",
+    "clear_spines(axs[1])\n",
+    "\n",
+    "for psf_idx, psf_name in enumerate(psf_names):\n",
+    "    # plot all of the points here. \n",
+    "    mi_means_across_photons = []\n",
+    "    recon_means_across_photons = []\n",
+    "    for photon_idx, photon_count in enumerate(mean_photon_count_list):\n",
+    "        color = color_for_photon_level(photon_count) \n",
+    "        mi_value = mis_across_psfs[psf_idx][photon_idx] \n",
+    "        recon_value = psnrs_across_psfs[psf_idx][photon_idx] \n",
+    "        axs[2].scatter(mi_value, recon_value, color=color, marker=marker_for_psf(psf_name), s=50, zorder=100)\n",
+    "        # add to lists to track later \n",
+    "        mi_means_across_photons.append(mi_value)\n",
+    "        recon_means_across_photons.append(recon_value)\n",
+    "    #ax.errorbar(mis_across_psfs[psf_idx], psnrs_across_psfs[psf_idx], xerr=[mi_error_lower[psf_idx], mi_error_upper[psf_idx]], yerr=[psnr_error_lower[psf_idx], psnr_error_upper[psf_idx]], fmt='o', capsize=5, ecolor='black', markersize=8, barsabove=True)\n",
+    "    \n",
+    "    mi_means_across_photons = np.array(mi_means_across_photons)\n",
+    "    recon_means_across_photons = np.array(recon_means_across_photons)\n",
+    "    axs[2].plot(mi_means_across_photons, recon_means_across_photons, '--', color='gray', alpha=1, linewidth=2)\n",
+    "#axs[2].set_xlabel(\"Mutual Information (bits per pixel)\")\n",
+    "axs[2].set_title(\"Peak Signal-to-Noise Ratio (dB)\")\n",
+    "clear_spines(axs[2])\n",
+    "\n",
+    "# norm = mpl.colors.Normalize(vmin=min_log_photons, vmax=max_log_photons)\n",
+    "# sm = plt.cm.ScalarMappable(cmap=colormap, norm=norm)\n",
+    "# sm.set_array([])\n",
+    "# cbar = plt.colorbar(sm, ax=axs[2], ticks=(np.log(valid_photon_counts)))\n",
+    "# # set tick labels\n",
+    "# cbar.ax.set_yticklabels(valid_photon_counts)\n",
+    "\n",
+    "\n",
+    "# cbar.set_label('Photons per pixel')\n",
+    "\n",
+    "#plt.savefig(\"metrics_vs_MI_with_confidence_intervals_log_photons.pdf\", bbox_inches='tight', transparent=True)"
+   ]
+  }
+ ],
+ "metadata": {
+  "kernelspec": {
+   "display_name": "infotransformer",
+   "language": "python",
+   "name": "python3"
+  },
+  "language_info": {
+   "codemirror_mode": {
+    "name": "ipython",
+    "version": 3
+   },
+   "file_extension": ".py",
+   "mimetype": "text/x-python",
+   "name": "python",
+   "nbconvert_exporter": "python",
+   "pygments_lexer": "ipython3",
+   "version": "3.10.14"
+  },
+  "orig_nbformat": 4
+ },
+ "nbformat": 4,
+ "nbformat_minor": 2
+}
diff --git a/lensless_imager/2024_10_23_pixelcnn_cifar10_updated_api_reruns_smaller_lr.py b/lensless_imager/2024_10_23_pixelcnn_cifar10_updated_api_reruns_smaller_lr.py
new file mode 100644
index 0000000..bd9148a
--- /dev/null
+++ b/lensless_imager/2024_10_23_pixelcnn_cifar10_updated_api_reruns_smaller_lr.py
@@ -0,0 +1,154 @@
+# ---
+# jupyter:
+#   jupytext:
+#     text_representation:
+#       extension: .py
+#       format_name: percent
+#       format_version: '1.3'
+#       jupytext_version: 1.15.2
+#   kernelspec:
+#     display_name: infotransformer
+#     language: python
+#     name: python3
+# ---
+
+# %%
+# %load_ext autoreload
+# %autoreload 2
+
+# Final MI estimation script for lensless imager, used in paper.
+
+import os
+from jax import config
+config.update("jax_enable_x64", True)
+import sys 
+sys.path.append('/home/lakabuli/workspace/EncodingInformation/src')
+
+os.environ["CUDA_DEVICE_ORDER"] = "PCI_BUS_ID" 
+os.environ["CUDA_VISIBLE_DEVICES"] = '0'
+from encoding_information.gpu_utils import limit_gpu_memory_growth
+limit_gpu_memory_growth()
+
+from cleanplots import *
+import jax.numpy as np
+import numpy as onp
+import tensorflow as tf
+import tensorflow.keras as tfk
+
+
+from lensless_helpers import *
+
+# %%
+from encoding_information import extract_patches
+from encoding_information.models import PixelCNN
+from encoding_information.plot_utils import plot_samples
+from encoding_information.models import PoissonNoiseModel
+from encoding_information.image_utils import add_noise
+from encoding_information import estimate_information
+
+# %% [markdown]
+# ### Sweep Photon Count and Diffusers
+
+# %%
+diffuser_psf = load_diffuser_32()
+one_psf = load_single_lens_uniform(32)
+two_psf = load_two_lens_uniform(32)
+three_psf = load_three_lens_uniform(32)
+four_psf = load_four_lens_uniform(32)
+five_psf = load_five_lens_uniform(32)
+
+# %%
+# set seed values for reproducibility
+seed_values_full = np.arange(1, 5)
+
+# set photon properties 
+bias = 10 # in photons
+mean_photon_count_list = [20, 40, 80, 160, 320]
+
+# set eligible psfs
+
+psf_patterns = [diffuser_psf, four_psf, one_psf] 
+psf_names = ['diffuser', 'four', 'one']
+
+# MI estimator parameters 
+patch_size = 32
+num_patches = 10000
+val_set_size = 1000
+test_set_size = 1500 
+num_samples = 8
+learning_rate = 1e-3  # use 5x the iterations per epoch, a smaller learning rate, and less patience, since the loss curve should be smoother
+num_iters_per_epoch = 500
+patience_val = 20
+
+
+save_dir = '/home/lakabuli/workspace/EncodingInformation/lensless_imager/mi_estimates_smaller_lr/'
+
+
+# %%
+for photon_count in mean_photon_count_list:
+    for index, psf_pattern in enumerate(psf_patterns):
+        val_loss_log = []
+        mi_estimates = []
+        lower_bounds = []
+        upper_bounds = []
+        for seed_value in seed_values_full:
+            # load dataset 
+            (x_train, y_train), (x_test, y_test) = tfk.datasets.cifar10.load_data() 
+            data = onp.concatenate((x_train, x_test), axis=0)
+            labels = np.concatenate((y_train, y_test), axis=0)
+            data = data.astype(np.float32)
+            # convert data to grayscale before converting to photons 
+            if len(data.shape) == 4:
+                data = tf.image.rgb_to_grayscale(data).numpy()
+                data = data.squeeze()
+            # convert to photons with mean value of photon_count 
+            data /= onp.mean(data)
+            data *= photon_count 
+            # make tiled data
+            random_data, random_labels = generate_random_tiled_data(data, labels, seed_value)
+
+            if psf_pattern is None:
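+                # no PSF: take a central crop of the tiled image, matching the size of the convolved output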
+                start_idx = data.shape[-1] // 2 
+                end_idx = data.shape[-1] // 2 - 1
+                psf_data = random_data[:, start_idx:-end_idx, start_idx:-end_idx]
+            else:
+                psf_data = convolved_dataset(psf_pattern, random_data) 
+            # add small bias to data 
+            psf_data += bias 
+            # make patches for training and testing splits, random patching 
+            patches = extract_patches(psf_data[:-test_set_size], patch_size=patch_size, num_patches=num_patches, seed=seed_value, verbose=True)
+            test_patches = extract_patches(psf_data[-test_set_size:], patch_size=patch_size, num_patches=test_set_size, seed=seed_value, verbose=True)  
+            # put all the clean patches together for use in the MI estimation function later
+            full_clean_patches = onp.concatenate([patches, test_patches])
+            # add noise to both sets 
+            patches_noisy = add_noise(patches, seed=seed_value)
+            test_patches_noisy = add_noise(test_patches, seed=seed_value)
+
+            # initialize pixelcnn 
+            pixel_cnn = PixelCNN() 
+            # fit the PixelCNN to the noisy patches; it defaults to 10% validation samples, which gives the desired 1000
+            # use a smaller learning rate this time, seed the fit, and let it run for the full training time
+            val_loss_history = pixel_cnn.fit(patches_noisy, seed=seed_value, learning_rate=learning_rate, do_lr_decay=False, steps_per_epoch=num_iters_per_epoch, patience=patience_val)
+            # generate samples, not necessary for MI sweeps
+            # pixel_cnn_samples = pixel_cnn.generate_samples(num_samples=num_samples)
+            # # visualize samples
+            # plot_samples([pixel_cnn_samples], test_patches, model_names=['PixelCNN'])
+
+            # instantiate noise model
+            noise_model = PoissonNoiseModel()
+            # estimate information using the fit pixelcnn and noise model, with clean data
+            pixel_cnn_info, pixel_cnn_lower_bound, pixel_cnn_upper_bound = estimate_information(pixel_cnn, noise_model, patches_noisy, 
+                                                                                                test_patches_noisy, clean_data=full_clean_patches, 
+                                                                                                confidence_interval=0.95)
+            print("PixelCNN estimated information: ", pixel_cnn_info)
+            print("PixelCNN lower bound: ", pixel_cnn_lower_bound)
+            print("PixelCNN upper bound: ", pixel_cnn_upper_bound)
+            # append results to lists
+            val_loss_log.append(val_loss_history)
+            mi_estimates.append(pixel_cnn_info)
+            lower_bounds.append(pixel_cnn_lower_bound)
+            upper_bounds.append(pixel_cnn_upper_bound)
+            np.save(save_dir + 'pixelcnn_mi_estimate_{}_photon_count_{}_psf.npy'.format(photon_count, psf_names[index]), np.array([mi_estimates, lower_bounds, upper_bounds]))
+            np.save(save_dir + 'pixelcnn_val_loss_{}_photon_count_{}_psf.npy'.format(photon_count, psf_names[index]), np.array(val_loss_log, dtype=object))
+        np.save(save_dir + 'pixelcnn_mi_estimate_{}_photon_count_{}_psf.npy'.format(photon_count, psf_names[index]), np.array([mi_estimates, lower_bounds, upper_bounds]))
+        np.save(save_dir + 'pixelcnn_val_loss_{}_photon_count_{}_psf.npy'.format(photon_count, psf_names[index]), np.array(val_loss_log, dtype=object))
diff --git a/lensless_imager/lensless_helpers.py b/lensless_imager/lensless_helpers.py
new file mode 100644
index 0000000..0693b28
--- /dev/null
+++ b/lensless_imager/lensless_helpers.py
@@ -0,0 +1,612 @@
+import numpy as np # use regular numpy for now, simpler
+import scipy
+from tqdm import tqdm
+# import tensorflow as tf
+# import tensorflow.keras as tfk
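+# NOTE: the TensorFlow-based helpers below (tf_cast, tf_labels, run_model_*) need the tensorflow /
+# tensorflow.keras imports above; they are left commented out here, presumably so the numpy-only
+# helpers can be imported without TensorFlow.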
+import gc
+import warnings
+
+import skimage
+import skimage.io
+from skimage.transform import resize
+
+# from tensorflow.keras.optimizers import SGD
+
+def tile_9_images(data_set):
+    # takes 9 images and forms a tiled image
+    assert len(data_set) == 9
+    return np.block([[data_set[0], data_set[1], data_set[2]],[data_set[3], data_set[4], data_set[5]],[data_set[6], data_set[7], data_set[8]]])
+
+def generate_random_tiled_data(x_set, y_set, seed_value=-1):
+    # takes a set of images and labels and returns a set of tiled images and corresponding labels
+    # the size of the output should be 3x the size of the input
+    vert_shape = x_set.shape[1] * 3
+    horiz_shape = x_set.shape[2] * 3
+    random_data = np.zeros((x_set.shape[0], vert_shape, horiz_shape)) # for mnist this was 84 x 84
+    random_labels = np.zeros((y_set.shape[0], 1))
+    if seed_value==-1:
+        np.random.seed()
+    else: 
+        np.random.seed(seed_value)
+    for i in range(x_set.shape[0]):
+        img_items = np.random.choice(x_set.shape[0], size=9, replace=True)
+        data_set = x_set[img_items]
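+        # the label for each tiled image is that of the center tile (index 4 of the 3x3 grid)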
+        random_labels[i] = y_set[img_items[4]]
+        random_data[i] = tile_9_images(data_set)
+    return random_data, random_labels
+
+def generate_repeated_tiled_data(x_set, y_set):
+    # takes set of images and labels and returns a set of repeated tiled images and corresponding labels, no randomness
+    # the size of the output is 3x the size of the input, this essentially is a wrapper for np.tile
+    repeated_data = np.tile(x_set, (1, 3, 3))
+    repeated_labels = y_set # the labels are just what they were
+    return repeated_data, repeated_labels
+
+def convolved_dataset(psf, random_tiled_data):
+    # takes an n x n psf and a set of 3n x 3n tiled images and returns their 'valid' convolutions
+    # the output is (3n - n + 1) = 2n + 1 pixels per side, i.e. two image widths plus one extra index
+    vert_shape = psf.shape[0] * 2 + 1
+    horiz_shape = psf.shape[1] * 2 + 1
+    psf_dataset = np.zeros((random_tiled_data.shape[0], vert_shape, horiz_shape)) # 57 x 57 for the case of mnist 28x28 images, 65 x 65 for the cifar 32 x 32 images
+    for i in tqdm(range(random_tiled_data.shape[0])):
+        psf_dataset[i] = scipy.signal.fftconvolve(psf, random_tiled_data[i], mode='valid')
+    return psf_dataset
+
+def compute_entropy(eigenvalues):
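+    # differential entropy (in bits) of a D-dimensional Gaussian with covariance eigenvalues lambda_i:
+    # H = 0.5 * (sum_i log2(lambda_i) + D * log2(2 * pi * e))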
+    sum_log_evs = np.sum(np.log2(eigenvalues))
+    D = eigenvalues.shape[0]
+    gaussian_entropy = 0.5 * (sum_log_evs + D * np.log2(2 * np.pi * np.e))
+    return gaussian_entropy
+
+def add_shot_noise(photon_scaled_images, photon_fraction=None, photons_per_pixel=None, assume_noiseless=True, seed_value=-1):
+    # adapted from Henry's version, with an added seed option
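+    # approximates Poisson shot noise with a Gaussian of matching variance: for a signal x scaled by
+    # photon_fraction f, the simulated measurement is ~ f * x + sqrt(f * x) * N(0, 1), clipped to be nonnegative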
+    if seed_value==-1:
+        np.random.seed()
+    else: 
+        np.random.seed(seed_value)
+    
+    # check that all pixel values are nonnegative
+    if np.any(photon_scaled_images < 0):
+        warnings.warn("Negative pixel values detected. Clipping to 0.")
+        photon_scaled_images[photon_scaled_images < 0] = 0
+    if photons_per_pixel is not None:
+        if photons_per_pixel > np.mean(photon_scaled_images):
+            warnings.warn(f"photons_per_pixel is greater than actual photon count ({photons_per_pixel}). Clipping to {np.mean(photon_scaled_images)}")
+            photons_per_pixel = np.mean(photon_scaled_images)
+        photon_fraction = photons_per_pixel / np.mean(photon_scaled_images)
+
+    if photon_fraction > 1:
+        warnings.warn(f"photon_fraction is greater than 1 ({photon_fraction}). Clipping to 1.")
+        photon_fraction = 1
+
+    if assume_noiseless:
+        additional_sd = np.sqrt(photon_fraction * photon_scaled_images)
+        if np.any(np.isnan(additional_sd)):
+            warnings.warn('There are nans here')
+            additional_sd[np.isnan(additional_sd)] = 0
+        # NOTE: something here behaves unexpectedly for RML data
+    #else:
+    #    additional_sd = np.sqrt(photon_fraction * photon_scaled_images) - photon_fraction * np.sqrt(photon_scaled_images)
+    simulated_images = photon_scaled_images * photon_fraction + additional_sd * np.random.randn(*photon_scaled_images.shape)
+    positive = np.array(simulated_images)
+    positive[positive < 0] = 0 # can't have negative photon counts
+    return np.array(positive)
+
+def tf_cast(data):
+    # normalizes data, loads it to a tensorflow array of type float32
+    return tf.cast(data / np.max(data), tf.float32)
+def tf_labels(labels):
+    # loads labels to a tensorflow array of type int64
+    return tf.cast(labels, tf.int64)
+
+
+
+def run_model_simple(train_data, train_labels, test_data, test_labels, val_data, val_labels, seed_value=-1):
+    if seed_value == -1:
+        seed_val = np.random.randint(10, 1000)
+        tfk.utils.set_random_seed(seed_val)
+    else:
+        tfk.utils.set_random_seed(seed_value)
+
+    model = tfk.models.Sequential()
+    model.add(tfk.layers.Flatten())
+    model.add(tfk.layers.Dense(256, activation='relu'))
+    model.add(tfk.layers.Dense(256, activation='relu'))
+    model.add(tfk.layers.Dense(10, activation='softmax'))
+
+    model.compile(optimizer='sgd', loss='sparse_categorical_crossentropy', metrics=['accuracy'])
+    
+    early_stop = tfk.callbacks.EarlyStopping(monitor="val_loss", # add in an early stopping option 
+                                        mode="min", patience=5,
+                                        restore_best_weights=True, verbose=1)
+    history = model.fit(train_data, train_labels, validation_data=(val_data, val_labels), batch_size=32, epochs=50, callbacks=[early_stop])
+    test_loss, test_acc = model.evaluate(test_data, test_labels)
+    return history, model, test_loss, test_acc
+
+def run_model_cnn(train_data, train_labels, test_data, test_labels, val_data, val_labels, seed_value=-1):
+    # structure from https://www.kaggle.com/code/cdeotte/how-to-choose-cnn-architecture-mnist
+    if seed_value == -1:
+        seed_val = np.random.randint(10, 1000)
+        tfk.utils.set_random_seed(seed_val)
+    else:
+        tfk.utils.set_random_seed(seed_value)
+
+    model = tfk.models.Sequential()
+    model.add(tfk.layers.Conv2D(64, kernel_size=5, padding='same', activation='relu', input_shape=(57, 57, 1))) # 64 and 128 filters work very slightly better
+    model.add(tfk.layers.MaxPool2D())
+    model.add(tfk.layers.Conv2D(128, kernel_size=5, padding='same', activation='relu'))
+    model.add(tfk.layers.MaxPool2D())
+    #model.add(tfk.layers.Conv2D(64, kernel_size=5, padding='same', activation='relu'))
+    #model.add(tfk.layers.MaxPool2D(padding='same'))
+    model.add(tfk.layers.Flatten())
+
+    #model.add(tfk.layers.Dense(256, activation='relu'))
+    model.add(tfk.layers.Dense(128, activation='relu'))
+    model.add(tfk.layers.Dense(10, activation='softmax'))
+
+    model.compile(optimizer='sgd', loss='sparse_categorical_crossentropy', metrics=['accuracy'])
+
+    early_stop = tfk.callbacks.EarlyStopping(monitor="val_loss", # add in an early stopping option 
+                                        mode="min", patience=5,
+                                        restore_best_weights=True, verbose=1)
+    history = model.fit(train_data, train_labels, validation_data=(val_data, val_labels), epochs=50, batch_size=32, callbacks=[early_stop]) #validation data is not test data
+    test_loss, test_acc = model.evaluate(test_data, test_labels)
+    return history, model, test_loss, test_acc
+
+def seeded_permutation(seed_value, n):
+    # given fixed seed returns permutation order
+    np.random.seed(seed_value)
+    permutation_order = np.random.permutation(n)
+    return permutation_order
+
+def segmented_indices(permutation_order, n, training_fraction, test_fraction):
+    #given permutation order returns indices for each of the three sets
+    training_indices = permutation_order[:int(training_fraction*n)]
+    test_indices = permutation_order[int(training_fraction*n):int((training_fraction+test_fraction)*n)]
+    validation_indices = permutation_order[int((training_fraction+test_fraction)*n):]
+    return training_indices, test_indices, validation_indices
+
+def permute_data(data, labels, seed_value, training_fraction=0.8, test_fraction=0.1): 
+    # the validation fraction is implicit: whatever remains after the training and test fractions
+    permutation_order = seeded_permutation(seed_value, data.shape[0])
+    training_indices, test_indices, validation_indices = segmented_indices(permutation_order, data.shape[0], training_fraction, test_fraction)
+
+    training_data = data[training_indices]
+    training_labels = labels[training_indices]
+    testing_data = data[test_indices]
+    testing_labels = labels[test_indices]
+    validation_data = data[validation_indices]
+    validation_labels = labels[validation_indices]
+
+    return (training_data, training_labels), (testing_data, testing_labels), (validation_data, validation_labels)
+
+def add_gaussian_noise(data, noise_level, seed_value=-1):
+    if seed_value==-1:
+        np.random.seed()
+    else: 
+        np.random.seed(seed_value)
+    return data + noise_level * np.random.randn(*data.shape)
+
+def confidence_bars(data_array, noise_length, confidence_interval=0.95):
+    # can also use confidence interval 0.9 or 0.99 if want slightly different bounds
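+    # data_array has shape (noise_length, num_trials); the mean and percentile bounds are taken across trials (axis=1)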
+    error_lo = np.percentile(data_array, 100 * (1 - confidence_interval) / 2, axis=1)
+    error_hi = np.percentile(data_array, 100 * (1 - (1 - confidence_interval) / 2), axis=1)
+    mean = np.mean(data_array, axis=1)
+    assert len(error_lo) == len(mean) == len(error_hi) == noise_length
+    return error_lo, error_hi, mean
+
+
+######### This function is outdated; do not use it. It was previously called test_system; use the functions below instead.
+#########
+def test_system_old(noise_level, psf_name, model_name, seed_values, data, labels, training_fraction, testing_fraction,  diffuser_region, phlat_region, psf, noise_type, rml_region):
+    # runs the model for the number of seeds given, returns the test accuracy for each seed
+    test_accuracy_list = []
+    for seed_value in seed_values:
+        seed_value = int(seed_value)
+        tfk.backend.clear_session()
+        gc.collect()
+        tfk.utils.set_random_seed(seed_value) # set random seed out here too?
+        training, testing, validation = permute_data(data, labels, seed_value, training_fraction, testing_fraction)
+        x_train, y_train = training
+        x_test, y_test = testing
+        x_validation, y_validation = validation
+
+        random_test_data, random_test_labels = generate_random_tiled_data(x_test, y_test, seed_value)
+        random_train_data, random_train_labels = generate_random_tiled_data(x_train, y_train, seed_value)
+        random_valid_data, random_valid_labels = generate_random_tiled_data(x_validation, y_validation, seed_value)
+
+        if psf_name == 'uc':
+            test_data = random_test_data[:, 14:-13, 14:-13]
+            train_data = random_train_data[:, 14:-13, 14:-13]
+            valid_data = random_valid_data[:, 14:-13, 14:-13]
+        if psf_name == 'psf_4':
+            test_data = convolved_dataset(psf, random_test_data)
+            train_data = convolved_dataset(psf, random_train_data)
+            valid_data = convolved_dataset(psf, random_valid_data)
+        if psf_name == 'diffuser':
+            test_data = convolved_dataset(diffuser_region, random_test_data)
+            train_data = convolved_dataset(diffuser_region, random_train_data)
+            valid_data = convolved_dataset(diffuser_region, random_valid_data)  
+        if psf_name == 'phlat':
+            test_data = convolved_dataset(phlat_region, random_test_data)
+            train_data = convolved_dataset(phlat_region, random_train_data)
+            valid_data = convolved_dataset(phlat_region, random_valid_data)
+        # 6/19/23 added RML option
+        if psf_name == 'rml':
+            test_data = convolved_dataset(rml_region, random_test_data)
+            train_data = convolved_dataset(rml_region, random_train_data)
+            valid_data = convolved_dataset(rml_region, random_valid_data)
+
+        # address any tiny floating point negative values, which only occur in RML data
+        if np.any(test_data < 0):
+            #print('negative values in test data for {} psf'.format(psf_name))
+            test_data[test_data < 0] = 0
+        if np.any(train_data < 0):
+            #print('negative values in train data for {} psf'.format(psf_name))
+            train_data[train_data < 0] = 0
+        if np.any(valid_data < 0):
+            #print('negative values in valid data for {} psf'.format(psf_name))
+            valid_data[valid_data < 0] = 0
+            
+
+        # additive gaussian noise, add noise after convolving, fixed 5/15/2023
+        if noise_type == 'gaussian':
+            test_data = add_gaussian_noise(test_data, noise_level, seed_value)
+            train_data = add_gaussian_noise(train_data, noise_level, seed_value)
+            valid_data = add_gaussian_noise(valid_data, noise_level, seed_value)
+        if noise_type == 'poisson':
+            test_data = add_shot_noise(test_data, photons_per_pixel=noise_level, seed_value=seed_value, assume_noiseless=True)
+            train_data = add_shot_noise(train_data, photons_per_pixel=noise_level, seed_value=seed_value, assume_noiseless=True)
+            valid_data = add_shot_noise(valid_data, photons_per_pixel=noise_level, seed_value=seed_value, assume_noiseless=True)
+
+        train_data, test_data, valid_data = tf_cast(train_data), tf_cast(test_data), tf_cast(valid_data)
+        random_train_labels, random_test_labels, random_valid_labels = tf_labels(random_train_labels), tf_labels(random_test_labels), tf_labels(random_valid_labels)
+        
+        if model_name == 'simple':
+            history, model, test_loss, test_acc = run_model_simple(train_data, random_train_labels, test_data, random_test_labels, valid_data, random_valid_labels, seed_value)
+        if model_name == 'cnn':
+            history, model, test_loss, test_acc = run_model_cnn(train_data, random_train_labels, test_data, random_test_labels, valid_data, random_valid_labels, seed_value)
+        test_accuracy_list.append(test_acc)
+    np.save('classification_results_rml_psf_619/test_accuracy_{}_noise_{}_{}_psf_{}_model.npy'.format(noise_level, noise_type, psf_name, model_name), test_accuracy_list)
+
+###### CNN for 32x32 CIFAR10 images
+# Originally written 11/14/2023, then lost in a merge and recopied 1/14/2024
+def run_model_cnn_cifar(train_data, train_labels, test_data, test_labels, val_data, val_labels, seed_value=-1, max_epochs=50, patience=5):
+    # structure from https://www.kaggle.com/code/cdeotte/how-to-choose-cnn-architecture-mnist
+    # defaults are 50 epochs and patience 5, but some runs recently need longer patience
+    if seed_value == -1:
+        seed_val = np.random.randint(10, 1000)
+        tfk.utils.set_random_seed(seed_val)
+    else:
+        tfk.utils.set_random_seed(seed_value)
+    model = tfk.models.Sequential()
+    model.add(tfk.layers.Conv2D(64, kernel_size=5, padding='same', activation='relu', input_shape=(65, 65, 1)))
+    model.add(tfk.layers.MaxPool2D())
+    model.add(tfk.layers.Conv2D(128, kernel_size=5, padding='same', activation='relu'))
+    model.add(tfk.layers.MaxPool2D())
+    #model.add(tfk.layers.Conv2D(64, kernel_size=5, padding='same', activation='relu'))
+    #model.add(tfk.layers.MaxPool2D(padding='same'))
+    model.add(tfk.layers.Flatten())
+
+    #model.add(tfk.layers.Dense(256, activation='relu'))
+    model.add(tfk.layers.Dense(128, activation='relu'))
+    model.add(tfk.layers.Dense(10, activation='softmax'))
+
+    model.compile(optimizer='sgd', loss='sparse_categorical_crossentropy', metrics=['accuracy'])
+
+    early_stop = tfk.callbacks.EarlyStopping(monitor="val_loss", # add in an early stopping option 
+                                        mode="min", patience=patience,
+                                        restore_best_weights=True, verbose=1)
+    history = model.fit(train_data, train_labels, validation_data=(val_data, val_labels), epochs=max_epochs, batch_size=32, callbacks=[early_stop]) #validation data is not test data
+    test_loss, test_acc = model.evaluate(test_data, test_labels)
+    return history, model, test_loss, test_acc
+
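+# Illustrative usage sketch, not part of the original experiments: train the CIFAR10 CNN on a
+# hypothetical random stack shaped like the expected (65, 65, 1) measurements. The function name
+# and the stand-in data below are placeholders for demonstration only.
+def _example_run_model_cnn_cifar(seed_value=10):
+    rng = np.random.default_rng(seed_value)
+    data = rng.uniform(size=(600, 65, 65, 1)).astype(np.float32)  # hypothetical measurement stack
+    labels = rng.integers(0, 10, size=600)                        # hypothetical CIFAR10-style labels
+    train = (data[:400], labels[:400])
+    test = (data[400:500], labels[400:500])
+    val = (data[500:], labels[500:])
+    return run_model_cnn_cifar(train[0], train[1], test[0], test[1], val[0], val[1],
+                               seed_value=seed_value, max_epochs=2, patience=1)
+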
+def make_ttv_sets(data, labels, seed_value, training_fraction, testing_fraction):
+    training, testing, validation = permute_data(data, labels, seed_value, training_fraction, testing_fraction)
+    training_data, training_labels = training
+    testing_data, testing_labels = testing
+    validation_data, validation_labels = validation
+    training_data, testing_data, validation_data = tf_cast(training_data), tf_cast(testing_data), tf_cast(validation_data)
+    training_labels, testing_labels, validation_labels = tf_labels(training_labels), tf_labels(testing_labels), tf_labels(validation_labels)
+    return (training_data, training_labels), (testing_data, testing_labels), (validation_data, validation_labels)
+
+def run_network_cifar(data, labels, seed_value, training_fraction, testing_fraction, mode='cnn', max_epochs=50, patience=5):
+    # small modification to be able to run 32x32 image data 
+    training, testing, validation = make_ttv_sets(data, labels, seed_value, training_fraction, testing_fraction)
+    if mode == 'cnn':
+        history, model, test_loss, test_acc = run_model_cnn_cifar(training[0], training[1],
+                                                            testing[0], testing[1],
+                                                            validation[0], validation[1], seed_value, max_epochs, patience)
+    elif mode == 'simple':
+        history, model, test_loss, test_acc = run_model_simple(training[0], training[1], 
+                                                                     testing[0], testing[1],
+                                                                     validation[0], validation[1], seed_value)
+    elif mode == 'new_cnn':
+        history, model, test_loss, test_acc = current_testing_model(training[0], training[1],
+                                                                    testing[0], testing[1], 
+                                                                    validation[0], validation[1], seed_value, max_epochs, patience)
+    elif mode == 'mom_cnn':
+        history, model, test_loss, test_acc = momentum_testing_model(training[0], training[1],
+                                                                    testing[0], testing[1],
+                                                                    validation[0], validation[1], seed_value, max_epochs, patience)
+    else:
+        raise ValueError('unrecognized mode: {}'.format(mode))
+    return history, model, test_loss, test_acc
+
+
+def load_diffuser_psf():
+    diffuser_psf = skimage.io.imread('psfs/diffuser_psf.png')
+    diffuser_psf = diffuser_psf[:,:,1]
+    diffuser_resize = diffuser_psf[200:500, 250:550]
+    diffuser_resize = resize(diffuser_resize, (100, 100), anti_aliasing=True)  #resize(diffuser_psf, (28, 28))
+    diffuser_region = diffuser_resize[:28, :28]
+    diffuser_region /=  np.sum(diffuser_region)
+    return diffuser_region
+
+def load_phlat_psf():
+    phlat_psf = skimage.io.imread('psfs/phlat_psf.png')
+    phlat_psf = phlat_psf[900:2900, 1500:3500, 1]
+    phlat_psf = resize(phlat_psf, (200, 200), anti_aliasing=True)
+    phlat_region = phlat_psf[10:38, 20:48]
+    phlat_region /= np.sum(phlat_region)
+    return phlat_region
+
+def load_4_psf():
+    psf = np.zeros((28, 28))
+    psf[20,20] = 1
+    psf[15, 10] = 1
+    psf[5, 13] = 1
+    psf[23, 6] = 1
+    psf = scipy.ndimage.gaussian_filter(psf, sigma=1)
+    psf /= np.sum(psf)
+    return psf
+
+# 6/9/23 added rml option
+def load_rml_psf():
+    rml_psf = skimage.io.imread('psfs/psf_8holes.png')
+    rml_psf = rml_psf[1000:3000, 1500:3500]
+    rml_psf_resize = resize(rml_psf, (100, 100), anti_aliasing=True)
+    rml_psf_region = rml_psf_resize[40:100, :60]
+    rml_psf_region = resize(rml_psf_region, (28, 28), anti_aliasing=True)
+    rml_psf_region /= np.sum(rml_psf_region)
+    return rml_psf_region
+
+def load_rml_new_psf():
+    rml_psf = skimage.io.imread('psfs/psf_8holes.png')
+    rml_psf = rml_psf[1000:3000, 1500:3500]
+    rml_psf_small = resize(rml_psf, (85, 85), anti_aliasing=True)
+    rml_psf_region = rml_psf_small[52:80, 10:38]
+    rml_psf_region /= np.sum(rml_psf_region)
+    return rml_psf_region 
+
+def load_single_lens():
+    one_lens = np.zeros((28, 28))
+    one_lens[14, 14] = 1
+    one_lens = scipy.ndimage.gaussian_filter(one_lens, sigma=0.8)
+    one_lens /= np.sum(one_lens)
+    return one_lens
+
+def load_two_lens():
+    two_lens = np.zeros((28, 28))
+    two_lens[10, 10] = 1
+    two_lens[20, 20] = 1
+    two_lens = scipy.ndimage.gaussian_filter(two_lens, sigma=0.8)
+    two_lens /= np.sum(two_lens)
+    return two_lens
+
+def load_three_lens():
+    three_lens = np.zeros((28, 28))
+    three_lens[8, 12] = 1 
+    three_lens[16, 20] = 1
+    three_lens[20, 7] = 1
+    three_lens = scipy.ndimage.gaussian_filter(three_lens, sigma=0.8)
+    three_lens /= np.sum(three_lens)
+    return three_lens
+
+
+def load_single_lens_32():
+    one_lens = np.zeros((32, 32))
+    one_lens[16, 16] = 1
+    one_lens = scipy.ndimage.gaussian_filter(one_lens, sigma=0.8)
+    one_lens /= np.sum(one_lens)
+    return one_lens
+
+def load_two_lens_32():
+    two_lens = np.zeros((32, 32))
+    two_lens[10, 10] = 1
+    two_lens[21, 21] = 1
+    two_lens = scipy.ndimage.gaussian_filter(two_lens, sigma=0.8)
+    two_lens /= np.sum(two_lens)
+    return two_lens
+
+def load_three_lens_32():
+    three_lens = np.zeros((32, 32))
+    three_lens[9, 12] = 1
+    three_lens[17, 22] = 1
+    three_lens[24, 8] = 1
+    three_lens = scipy.ndimage.gaussian_filter(three_lens, sigma=0.8)
+    three_lens /= np.sum(three_lens)
+    return three_lens
+
+def load_four_lens_32():
+    psf = np.zeros((32, 32))
+    psf[22, 22] = 1
+    psf[15, 10] = 1
+    psf[5, 12] = 1
+    psf[28, 8] = 1
+    psf = scipy.ndimage.gaussian_filter(psf, sigma=1) # note that this one is sigma 1, for mnist it's sigma 0.8
+    psf /= np.sum(psf)
+    return psf
+
+def load_diffuser_32():
+    diffuser_psf = skimage.io.imread('psfs/diffuser_psf.png')
+    diffuser_psf = diffuser_psf[:,:,1]
+    diffuser_resize = diffuser_psf[200:500, 250:550]
+    diffuser_resize = resize(diffuser_resize, (100, 100), anti_aliasing=True)  #resize(diffuser_psf, (28, 28))
+    diffuser_region = diffuser_resize[:32, :32]
+    diffuser_region /=  np.sum(diffuser_region)
+    return diffuser_region
+
+
+
+### 10/15/2023: Make new versions of the model functions that train with Datasets - first attempt failed
+
+# lenses with centralized positions for use in task-specific estimations
+def load_single_lens_uniform(size=32):
+    one_lens = np.zeros((size, size))
+    one_lens[16, 16] = 1
+    one_lens = scipy.ndimage.gaussian_filter(one_lens, sigma=0.8)
+    one_lens /= np.sum(one_lens)
+    return one_lens
+
+def load_two_lens_uniform(size=32):
+    two_lens = np.zeros((size, size))
+    two_lens[16, 16] = 1 
+    two_lens[7, 9] = 1
+    two_lens = scipy.ndimage.gaussian_filter(two_lens, sigma=0.8)
+    two_lens /= np.sum(two_lens)
+    return two_lens
+
+def load_three_lens_uniform(size=32):
+    three_lens = np.zeros((size, size))
+    three_lens[16, 16] = 1
+    three_lens[7, 9] = 1
+    three_lens[23, 21] = 1
+    three_lens = scipy.ndimage.gaussian_filter(three_lens, sigma=0.8)
+    three_lens /= np.sum(three_lens)
+    return three_lens
+
+def load_four_lens_uniform(size=32):
+    four_lens = np.zeros((size, size))
+    four_lens[16, 16] = 1
+    four_lens[7, 9] = 1
+    four_lens[23, 21] = 1
+    four_lens[8, 24] = 1
+    four_lens = scipy.ndimage.gaussian_filter(four_lens, sigma=0.8)
+    four_lens /= np.sum(four_lens)
+    return four_lens
+
+def load_five_lens_uniform(size=32):
+    five_lens = np.zeros((size, size))
+    five_lens[16, 16] = 1
+    five_lens[7, 9] = 1
+    five_lens[23, 21] = 1
+    five_lens[8, 24] = 1
+    five_lens[21, 5] = 1
+    five_lens = scipy.ndimage.gaussian_filter(five_lens, sigma=0.8)
+    five_lens /= np.sum(five_lens)
+    return five_lens
+
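+# Illustrative sketch (not the pipeline's own convolution helper): simulate a noiseless
+# measurement by convolving a normalized test image with one of the uniform-lens PSFs.
+# The flat test image below is a hypothetical placeholder.
+def _example_simulate_measurement(image=None):
+    from scipy.signal import fftconvolve
+    psf = load_four_lens_uniform(32)
+    if image is None:
+        image = np.ones((32, 32)) / (32 * 32)  # hypothetical flat test image
+    # PSFs are normalized to sum to 1, so total intensity is preserved up to edge effects
+    return fftconvolve(image, psf, mode='same')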
+
+
+## 01/24/2024 new CNN that's slightly deeper 
+def current_testing_model(train_data, train_labels, test_data, test_labels, val_data, val_labels, seed_value=-1, max_epochs=50, patience=20):
+    # structure from https://www.kaggle.com/code/amyjang/tensorflow-cifar10-cnn-tutorial 
+    
+    if seed_value == -1:
+        seed_val = np.random.randint(10, 1000)
+        tfk.utils.set_random_seed(seed_val)
+    else:
+        tfk.utils.set_random_seed(seed_value)
+
+    model = tf.keras.models.Sequential([
+    tf.keras.layers.Conv2D(32, kernel_size=5, padding='same', input_shape=(65, 65, 1), activation='relu'),
+    tf.keras.layers.Conv2D(32, kernel_size=5, activation='relu'),
+    tf.keras.layers.MaxPooling2D(),
+    tf.keras.layers.Dropout(0.25),
+
+    tf.keras.layers.Conv2D(64, kernel_size=3, padding='same', activation='relu'),
+    tf.keras.layers.Conv2D(64, kernel_size=3, activation='relu'),
+    tf.keras.layers.MaxPooling2D(),
+    tf.keras.layers.Dropout(0.25),
+
+    tf.keras.layers.Flatten(),
+    tf.keras.layers.Dense(512, activation='relu'),
+    tf.keras.layers.Dropout(0.5),
+    tf.keras.layers.Dense(128, activation='relu'),
+    tf.keras.layers.Dense(10, activation='softmax'),
+    ])
+
+    model.compile(optimizer='sgd', loss='sparse_categorical_crossentropy', metrics=['accuracy'])
+
+    early_stop = tfk.callbacks.EarlyStopping(monitor="val_loss", # add in an early stopping option 
+                                            mode="min", patience=patience,
+                                            restore_best_weights=True, verbose=1)
+    print(model.optimizer.get_config())
+
+    history = model.fit(train_data, train_labels, validation_data=(val_data, val_labels), epochs=max_epochs, batch_size=32, callbacks=[early_stop]) #validation data is not test data
+    test_loss, test_acc = model.evaluate(test_data, test_labels)
+    return history, model, test_loss, test_acc
+
+
+
+
+## 01/24/2024 same deeper CNN, trained with SGD using Nesterov momentum
+def momentum_testing_model(train_data, train_labels, test_data, test_labels, val_data, val_labels, seed_value=-1, max_epochs=50, patience=20):
+    # structure from https://www.kaggle.com/code/amyjang/tensorflow-cifar10-cnn-tutorial
+    # uses Nesterov momentum rather than plain momentum
+    if seed_value == -1:
+        seed_val = np.random.randint(10, 1000)
+        tfk.utils.set_random_seed(seed_val)
+    else:
+        tfk.utils.set_random_seed(seed_value)
+
+    model = tf.keras.models.Sequential([
+    tf.keras.layers.Conv2D(32, kernel_size=5, padding='same', input_shape=(65, 65, 1), activation='relu'),
+    tf.keras.layers.Conv2D(32, kernel_size=5, activation='relu'),
+    tf.keras.layers.MaxPooling2D(),
+    tf.keras.layers.Dropout(0.25),
+
+    tf.keras.layers.Conv2D(64, kernel_size=3, padding='same', activation='relu'),
+    tf.keras.layers.Conv2D(64, kernel_size=3, activation='relu'),
+    tf.keras.layers.MaxPooling2D(),
+    tf.keras.layers.Dropout(0.25),
+
+    tf.keras.layers.Flatten(),
+    tf.keras.layers.Dense(512, activation='relu'),
+    tf.keras.layers.Dropout(0.5),
+    tf.keras.layers.Dense(128, activation='relu'),
+    tf.keras.layers.Dense(10, activation='softmax'),
+    ])
+
+    model.compile(optimizer=tfk.optimizers.SGD(momentum=0.9, nesterov=True), loss='sparse_categorical_crossentropy', metrics=['accuracy'])
+
+    early_stop = tfk.callbacks.EarlyStopping(monitor="val_loss", # add in an early stopping option 
+                                            mode="min", patience=patience,
+                                            restore_best_weights=True, verbose=1)
+    
+    print(model.optimizer.get_config())
+
+    history = model.fit(train_data, train_labels, validation_data=(val_data, val_labels), epochs=max_epochs, batch_size=32, callbacks=[early_stop]) #validation data is not test data
+    test_loss, test_acc = model.evaluate(test_data, test_labels)
+    return history, model, test_loss, test_acc
+
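+# Quick illustrative sketch of the optimizer difference between the two deeper CNNs above:
+# current_testing_model compiles with the default 'sgd' (momentum 0), while
+# momentum_testing_model uses SGD with Nesterov momentum 0.9. The helper name is hypothetical.
+def _example_compare_sgd_configs():
+    plain = tfk.optimizers.SGD()
+    nesterov = tfk.optimizers.SGD(momentum=0.9, nesterov=True)
+    return plain.get_config(), nesterov.get_config()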
+
+# bootstrap resampling of per-image metrics to estimate the spread of their means over the test set
+def compute_bootstraps(mses, psnrs, ssims, test_set_length, num_bootstraps=100): 
+    bootstrap_mses = []
+    bootstrap_psnrs = []
+    bootstrap_ssims = []
+    for bootstrap_idx in tqdm(range(num_bootstraps), desc='Bootstrapping to compute confidence interval'):
+        # select indices for sampling
+        bootstrap_indices = np.random.choice(test_set_length, test_set_length, replace=True) 
+        # take the metric values at those indices
+        bootstrap_selected_mses = mses[bootstrap_indices]
+        bootstrap_selected_psnrs = psnrs[bootstrap_indices]
+        bootstrap_selected_ssims = ssims[bootstrap_indices]
+        # accumulate the mean of the selected metric values
+        bootstrap_mses.append(np.mean(bootstrap_selected_mses))
+        bootstrap_psnrs.append(np.mean(bootstrap_selected_psnrs))
+        bootstrap_ssims.append(np.mean(bootstrap_selected_ssims))
+    bootstrap_mses = np.array(bootstrap_mses)
+    bootstrap_psnrs = np.array(bootstrap_psnrs)
+    bootstrap_ssims = np.array(bootstrap_ssims)
+    return bootstrap_mses, bootstrap_psnrs, bootstrap_ssims
+
+def compute_confidence_interval(list_of_items, confidence_interval=0.95):
+    # returns the mean and the lower/upper percentile bounds of a two-sided confidence interval
+    assert 0 < confidence_interval < 1
+    mean_value = np.mean(list_of_items)
+    lower_bound = np.percentile(list_of_items, 50 * (1 - confidence_interval))
+    upper_bound = np.percentile(list_of_items, 50 * (1 + confidence_interval))
+    return mean_value, lower_bound, upper_bound
+
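+
+# Illustrative sketch combining the two helpers above: bootstrap per-image metrics from a
+# hypothetical evaluation, then report each mean with a 95% confidence interval. The metric
+# values below are random placeholders, not real deconvolution results.
+def _example_bootstrap_confidence_interval(seed_value=10, test_set_length=100):
+    rng = np.random.default_rng(seed_value)
+    mses = rng.uniform(0.01, 0.05, size=test_set_length)
+    psnrs = rng.uniform(18, 25, size=test_set_length)
+    ssims = rng.uniform(0.6, 0.9, size=test_set_length)
+    boot_mses, boot_psnrs, boot_ssims = compute_bootstraps(mses, psnrs, ssims, test_set_length)
+    return (compute_confidence_interval(boot_mses),
+            compute_confidence_interval(boot_psnrs),
+            compute_confidence_interval(boot_ssims))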