From 13e3e73076c976a6306be41e865d7da484396f24 Mon Sep 17 00:00:00 2001 From: Leyla Kabuli <lakabuli@gmail.com> Date: Mon, 18 Nov 2024 13:24:05 -0800 Subject: [PATCH] lensless mi experiments with updated api --- ...upervised_wiener_deconvolution_per_lens.py | 176 +++++ ...s_deconvolution_plots_cifar10_figure.ipynb | 591 +++++++++++++++++ ...n_cifar10_updated_api_reruns_smaller_lr.py | 154 +++++ lensless_imager/lensless_helpers.py | 612 ++++++++++++++++++ 4 files changed, 1533 insertions(+) create mode 100644 lensless_imager/2024_10_22_sweep_unsupervised_wiener_deconvolution_per_lens.py create mode 100644 lensless_imager/2024_10_23_mi_vs_deconvolution_plots_cifar10_figure.ipynb create mode 100644 lensless_imager/2024_10_23_pixelcnn_cifar10_updated_api_reruns_smaller_lr.py create mode 100644 lensless_imager/lensless_helpers.py diff --git a/lensless_imager/2024_10_22_sweep_unsupervised_wiener_deconvolution_per_lens.py b/lensless_imager/2024_10_22_sweep_unsupervised_wiener_deconvolution_per_lens.py new file mode 100644 index 0000000..15fc594 --- /dev/null +++ b/lensless_imager/2024_10_22_sweep_unsupervised_wiener_deconvolution_per_lens.py @@ -0,0 +1,176 @@ +# --- +# jupyter: +# jupytext: +# text_representation: +# extension: .py +# format_name: percent +# format_version: '1.3' +# jupytext_version: 1.15.2 +# kernelspec: +# display_name: infotransformer +# language: python +# name: python3 +# --- + +# %% [markdown] +# ## Sweeping both unsupervised Wiener deconvolution and regular Wiener deconvolution with a hand-tuned regularization parameter +# +# Using a fixed seed (10) for consistency. + +# %% +# %load_ext autoreload +# %autoreload 2 + +import os +from jax import config +config.update("jax_enable_x64", True) +import sys +sys.path.append('/home/lakabuli/workspace/EncodingInformation/src') + +os.environ["CUDA_DEVICE_ORDER"] = "PCI_BUS_ID" +os.environ["CUDA_VISIBLE_DEVICES"] = '1' +from encoding_information.gpu_utils import limit_gpu_memory_growth +limit_gpu_memory_growth() + + +from cleanplots import * +import numpy as np +import tensorflow as tf +import tensorflow.keras as tfk + +from lensless_helpers import * +from tqdm import tqdm + +# %% +from encoding_information.image_utils import add_noise +from skimage.restoration import wiener, unsupervised_wiener, richardson_lucy +import skimage.metrics as skm + +# %% +# load the PSFs + +diffuser_psf = load_diffuser_32() +one_psf = load_single_lens_uniform(32) +two_psf = load_two_lens_uniform(32) +three_psf = load_three_lens_uniform(32) +four_psf = load_four_lens_uniform(32) +five_psf = load_five_lens_uniform(32) +aperture_psf = np.copy(diffuser_psf) +aperture_psf[:5] = 0 +aperture_psf[-5:] = 0 +aperture_psf[:,:5] = 0 +aperture_psf[:,-5:] = 0 + + +# %% +def compute_skm_metrics(gt, recon): + # takes in already normalized gt + mse = skm.mean_squared_error(gt, recon) + psnr = skm.peak_signal_noise_ratio(gt, recon) + nmse = skm.normalized_root_mse(gt, recon) + ssim = skm.structural_similarity(gt, recon, data_range=1) + return mse, psnr, nmse, ssim + + +# %% +# set seed values for reproducibility +seed_values_full = np.arange(1, 4) + +# set photon properties +#mean_photon_count_list = [20, 40, 60, 80, 100, 150, 200, 250, 300] +mean_photon_count_list = [20, 40, 80, 160, 320] + +# set eligible psfs + +psf_patterns_use = [one_psf, four_psf, diffuser_psf] +psf_names_use = ['one', 'four', 'diffuser'] + +save_dir = '/home/lakabuli/workspace/EncodingInformation/lensless_imager/deconvolutions/' + + +# MI estimator parameters +patch_size = 32 +num_patches = 10000
+test_set_size = 1500 +bs = 500 +max_epochs = 50 + +seed_value = 10 + +reg_value_best = 10**-2 + +# %% +# data generation process + +for photon_count in mean_photon_count_list: + for psf_idx, psf_pattern in enumerate(psf_patterns_use): + # load dataset + (x_train, y_train), (x_test, y_test) = tfk.datasets.cifar10.load_data() + data = np.concatenate((x_train, x_test), axis=0) + data = data.astype(np.float64) + labels = np.concatenate((y_train, y_test), axis=0) # make one big glob of labels. + # convert data to grayscale before converting to photons + if len(data.shape) == 4: + data = tf.image.rgb_to_grayscale(data).numpy() + data = data.squeeze() + # convert to photons with mean value of photon_count + data /= np.mean(data) + data *= photon_count + # get maximum value in this data + max_val = np.max(data) + # make tiled data + random_data, random_labels = generate_random_tiled_data(data, labels, seed_value) + # only keep the middle part of the data + data_padded = np.zeros((data.shape[0], 96, 96)) + data_padded[:, 32:64, 32:64] = random_data[:, 32:64, 32:64] + # save the middle part of the data as the gt for metric computation, include only the test set portion. + gt_data = data_padded[:, 32:64, 32:64] + gt_data = gt_data[-test_set_size:] + # extract the test set before doing convolution + test_data = data_padded[-test_set_size:] + # convolve the data + convolved_data = convolved_dataset(psf_pattern, test_data) + convolved_data_noisy = add_noise(convolved_data, seed=seed_value) + # output of add_noise is a jax array that's float32, convert to regular numpy array and float64. + convolved_data_noisy = np.array(convolved_data_noisy).astype(np.float64) + + # compute metrics using unsupervised wiener deconvolution + mse_psf = [] + psnr_psf = [] + ssim_psf = [] + for i in tqdm(range(convolved_data_noisy.shape[0])): + recon, _ = unsupervised_wiener(convolved_data_noisy[i] / max_val, psf_pattern) + recon = recon[17:49, 17:49] #this is the crop window to look at + mse = skm.mean_squared_error(gt_data[i] / max_val, recon) + psnr = skm.peak_signal_noise_ratio(gt_data[i] / max_val, recon) + ssim = skm.structural_similarity(gt_data[i] / max_val, recon, data_range=1) + mse_psf.append(mse) + psnr_psf.append(psnr) + ssim_psf.append(ssim) + + print('PSF: {}, Mean MSE: {}, Mean PSNR: {}, Mean SSIM: {}'.format(psf_names_use[psf_idx], np.mean(mse_psf), np.mean(psnr_psf), np.mean(ssim_psf))) + np.save(save_dir + 'unsupervised_wiener_recon_{}_photon_count_{}_psf.npy'.format(photon_count, psf_names_use[psf_idx]), [mse_psf, psnr_psf, ssim_psf]) + + # repeat with regular deconvolution + mse_psf = [] + psnr_psf = [] + ssim_psf = [] + for i in tqdm(range(convolved_data_noisy.shape[0])): + recon = wiener(convolved_data_noisy[i] / max_val, psf_pattern, reg_value_best) + recon = recon[17:49, 17:49] #this is the crop window to look at + mse = skm.mean_squared_error(gt_data[i] / max_val, recon) + psnr = skm.peak_signal_noise_ratio(gt_data[i] / max_val, recon) + ssim = skm.structural_similarity(gt_data[i] / max_val, recon, data_range=1) + mse_psf.append(mse) + psnr_psf.append(psnr) + ssim_psf.append(ssim) + print('PSF: {}, Mean MSE: {}, Mean PSNR: {}, Mean SSIM: {}'.format(psf_names_use[psf_idx], np.mean(mse_psf), np.mean(psnr_psf), np.mean(ssim_psf))) + np.save(save_dir + 'regular_wiener_recon_{}_photon_count_{}_psf.npy'.format(photon_count, psf_names_use[psf_idx]), [mse_psf, psnr_psf, ssim_psf]) + + + + + +# %% + + diff --git a/lensless_imager/2024_10_23_mi_vs_deconvolution_plots_cifar10_figure.ipynb 
b/lensless_imager/2024_10_23_mi_vs_deconvolution_plots_cifar10_figure.ipynb new file mode 100644 index 0000000..97c9575 --- /dev/null +++ b/lensless_imager/2024_10_23_mi_vs_deconvolution_plots_cifar10_figure.ipynb @@ -0,0 +1,591 @@ +{ + "cells": [ + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "# Make the plot for MI and deconvolution relationship for paper figure, 2024/10/23" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "%load_ext autoreload \n", + "%autoreload 2" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "from jax import config\n", + "config.update(\"jax_enable_x64\", True)\n", + "import numpy as np\n", + "\n", + "import sys \n", + "sys.path.append('/home/lakabuli/workspace/EncodingInformation/src')\n", + "from lensless_helpers import *\n", + "import os\n", + "os.environ[\"CUDA_DEVICE_ORDER\"] = \"PCI_BUS_ID\" \n", + "os.environ[\"CUDA_VISIBLE_DEVICES\"] = '0'\n", + "print(os.environ.get('PYTHONPATH'))\n", + "from cleanplots import * " + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "seed_value = 10\n", + "\n", + "# set photon properties \n", + "bias = 10 # in photons\n", + "mean_photon_count_list = [20, 40, 80, 160, 320]\n", + "max_photon_count = mean_photon_count_list[-1]\n", + "\n", + "# set eligible psfs\n", + "\n", + "psf_names = ['one', 'four', 'diffuser']\n", + "\n", + "# MI estimator parameters \n", + "patch_size = 32\n", + "num_patches = 10000\n", + "val_set_size = 1000\n", + "test_set_size = 1500\n", + "\n", + "mi_dir = '/home/lakabuli/workspace/EncodingInformation/lensless_imager/mi_estimates_smaller_lr/'\n", + "recon_dir = '/home/lakabuli/workspace/EncodingInformation/lensless_imager/deconvolutions/'" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "## Load MI data and make plots of it\n", + "\n", + "The plot has essentially invisible error bars. 
No outlier issues" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "from cleanplots import *\n", + "get_color_cycle()[0]" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "fig, ax = plt.subplots(1, 1, figsize=(8, 6))\n", + "mis_across_psfs = []\n", + "lowers_across_psfs = []\n", + "uppers_across_psfs = []\n", + "for psf_name in psf_names:\n", + " mis = []\n", + " lowers = []\n", + " uppers = []\n", + " for photon_count in mean_photon_count_list:\n", + " mi_estimates = np.load(mi_dir + 'pixelcnn_mi_estimate_{}_photon_count_{}_psf.npy'.format(photon_count, psf_name))\n", + " mi_values = mi_estimates[0]\n", + " print(np.max(mi_values) - np.min(mi_values))\n", + " lower_bounds = mi_estimates[1]\n", + " upper_bounds = mi_estimates[2]\n", + " # get index that has smallest mi value across the different model runs.\n", + " min_mi_index = np.argmin(mi_values)\n", + " mis.append(mi_values[min_mi_index])\n", + " lowers.append(lower_bounds[min_mi_index])\n", + " uppers.append(upper_bounds[min_mi_index])\n", + " ax.plot(mean_photon_count_list, mis, label=psf_name) \n", + " ax.fill_between(mean_photon_count_list, lowers, uppers, alpha=0.3)\n", + " mis_across_psfs.append(mis)\n", + " lowers_across_psfs.append(lowers)\n", + " uppers_across_psfs.append(uppers)\n", + "plt.legend()\n", + "plt.title(\"PixelCNN MI estimates across Photon Count, CIFAR10\")\n", + "plt.xlabel(\"Mean Photon Count\")\n", + "plt.ylabel(\"Estimated Mutual Information\")\n", + "mis_across_psfs = np.array(mis_across_psfs)\n", + "lowers_across_psfs = np.array(lowers_across_psfs)\n", + "uppers_across_psfs = np.array(uppers_across_psfs)" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "## Load recon data and make plots of it" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "mses_across_psfs = []\n", + "mse_lowers_across_psfs = []\n", + "mse_uppers_across_psfs = []\n", + "psnrs_across_psfs = []\n", + "psnr_lowers_across_psfs = []\n", + "psnr_uppers_across_psfs = []\n", + "ssims_across_psfs = []\n", + "ssim_lowers_across_psfs = []\n", + "ssim_uppers_across_psfs = []\n", + "\n", + "for psf_name in psf_names: \n", + " mse_vals = []\n", + " mse_lowers = []\n", + " mse_uppers = []\n", + " psnr_vals = []\n", + " psnr_lowers = []\n", + " psnr_uppers = []\n", + " ssim_vals = []\n", + " ssim_lowers = []\n", + " ssim_uppers = []\n", + " for photon_count in mean_photon_count_list:\n", + " metrics = np.load(recon_dir + 'unsupervised_wiener_recon_{}_photon_count_{}_psf.npy'.format(photon_count, psf_name))\n", + " mse = metrics[0]\n", + " psnr = metrics[1] \n", + " ssim = metrics[2]\n", + " bootstrap_mse, bootstrap_psnr, bootstrap_ssim = compute_bootstraps(mse, psnr, ssim, test_set_size)\n", + " mean_mse, lower_bound_mse, upper_bound_mse = compute_confidence_interval(bootstrap_mse, confidence_interval=0.95)\n", + " mean_psnr, lower_bound_psnr, upper_bound_psnr = compute_confidence_interval(bootstrap_psnr, confidence_interval=0.95)\n", + " mean_ssim, lower_bound_ssim, upper_bound_ssim = compute_confidence_interval(bootstrap_ssim, confidence_interval=0.95)\n", + " mse_vals.append(mean_mse)\n", + " mse_lowers.append(lower_bound_mse)\n", + " mse_uppers.append(upper_bound_mse)\n", + " psnr_vals.append(mean_psnr)\n", + " psnr_lowers.append(lower_bound_psnr)\n", + " psnr_uppers.append(upper_bound_psnr)\n", + " ssim_vals.append(mean_ssim)\n", 
+ " ssim_lowers.append(lower_bound_ssim)\n", + " ssim_uppers.append(upper_bound_ssim)\n", + " mses_across_psfs.append(mse_vals)\n", + " mse_lowers_across_psfs.append(mse_lowers)\n", + " mse_uppers_across_psfs.append(mse_uppers)\n", + " psnrs_across_psfs.append(psnr_vals)\n", + " psnr_lowers_across_psfs.append(psnr_lowers)\n", + " psnr_uppers_across_psfs.append(psnr_uppers)\n", + " ssims_across_psfs.append(ssim_vals)\n", + " ssim_lowers_across_psfs.append(ssim_lowers)\n", + " ssim_uppers_across_psfs.append(ssim_uppers)\n", + "mses_across_psfs = np.array(mses_across_psfs)\n", + "mse_lowers_across_psfs = np.array(mse_lowers_across_psfs)\n", + "mse_uppers_across_psfs = np.array(mse_uppers_across_psfs)\n", + "psnrs_across_psfs = np.array(psnrs_across_psfs)\n", + "psnr_lowers_across_psfs = np.array(psnr_lowers_across_psfs)\n", + "psnr_uppers_across_psfs = np.array(psnr_uppers_across_psfs)\n", + "ssims_across_psfs = np.array(ssims_across_psfs)\n", + "ssim_lowers_across_psfs = np.array(ssim_lowers_across_psfs)\n", + "ssim_uppers_across_psfs = np.array(ssim_uppers_across_psfs)\n", + "plt.figure(figsize=(20, 5))\n", + "plt.subplot(1, 3, 1)\n", + "for i in range(len(psf_names)):\n", + " plt.plot(mean_photon_count_list, mses_across_psfs[i], label=psf_names[i])\n", + " plt.fill_between(mean_photon_count_list, mse_lowers_across_psfs[i], mse_uppers_across_psfs[i], alpha=0.5)\n", + "plt.title(\"MSE\")\n", + "plt.legend()\n", + "plt.subplot(1, 3, 2)\n", + "for i in range(len(psf_names)):\n", + " plt.plot(mean_photon_count_list, psnrs_across_psfs[i], label=psf_names[i])\n", + " plt.fill_between(mean_photon_count_list, psnr_lowers_across_psfs[i], psnr_uppers_across_psfs[i], alpha=0.5)\n", + "plt.title(\"PSNR\")\n", + "plt.subplot(1, 3, 3)\n", + "for i in range(len(psf_names)):\n", + " plt.plot(mean_photon_count_list, ssims_across_psfs[i], label=psf_names[i])\n", + " plt.fill_between(mean_photon_count_list, ssim_lowers_across_psfs[i], ssim_uppers_across_psfs[i], alpha=0.5)\n", + "plt.title(\"SSIM\")\n", + "plt.legend()" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "## Make figures, omitting error bars since smaller than marker size and reverting to circular markers" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "### Setup" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "def marker_for_psf(psf_name):\n", + " if psf_name =='one':\n", + " marker = 'o'\n", + " elif psf_name == 'four':\n", + " marker = 'o'\n", + " #marker = 's' \n", + " elif psf_name == 'diffuser':\n", + " #marker = '*'\n", + " marker = 'o'\n", + " elif psf_name == 'uc':\n", + " marker = 'x'\n", + " elif psf_name =='two':\n", + " marker = 'd'\n", + " elif psf_name == 'three':\n", + " marker = 'v'\n", + " elif psf_name == 'five':\n", + " marker = 'p'\n", + " elif psf_name == 'aperture':\n", + " marker = 'P'\n", + " return marker" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "# Choose a base colormap\n", + "base_colormap = plt.get_cmap('inferno')\n", + "# Define the start and end points--used so that high values aren't too light against white background\n", + "start, end = 0, 0.88 # making end point 0.8\n", + "from matplotlib.colors import LinearSegmentedColormap\n", + "# Create a new colormap from the portion of the original colormap\n", + "colormap = LinearSegmentedColormap.from_list(\n", + " 'trunc({n},{a:.2f},{b:.2f})'.format(n=base_colormap.name, a=start, 
b=end),\n", + " base_colormap(np.linspace(start, end, 256))\n", + ")\n", + "\n", + "min_photons_per_pixel = min(mean_photon_count_list)\n", + "max_photons_per_pixel = max(mean_photon_count_list)\n", + "\n", + "min_log_photons = np.log(min_photons_per_pixel)\n", + "max_log_photons = np.log(max_photons_per_pixel)\n", + "\n", + "def color_for_photon_level(photons_per_pixel):\n", + " log_photons = np.log(photons_per_pixel)\n", + " return colormap((log_photons - min_log_photons) / (max_log_photons - min_log_photons) )" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "# old format for selecting target indices, now not used much\n", + "metric_type = 1 # 0 for MSE, 1 for PSNR \n", + "valid_psfs = [0, 1, 2]\n", + "valid_photon_counts = [20, 40, 80, 160, 320]\n", + "psf_names = [psf_names[i] for i in valid_psfs]\n", + "print(psf_names)" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "mse_error_lower = np.abs(mses_across_psfs - mse_lowers_across_psfs)\n", + "mse_error_upper = np.abs(mse_uppers_across_psfs - mses_across_psfs)\n", + "psnr_error_lower = np.abs(psnrs_across_psfs - psnr_lowers_across_psfs)\n", + "psnr_error_upper = np.abs(psnr_uppers_across_psfs - psnrs_across_psfs)\n", + "mi_error_lower = np.abs(mis_across_psfs - lowers_across_psfs)\n", + "mi_error_upper = np.abs(uppers_across_psfs - mis_across_psfs)" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "fig, ax = plt.subplots(1, 1, figsize=(6, 5))\n", + "for psf_idx, psf_name in enumerate(psf_names):\n", + " # plot all of the points here. \n", + " mi_means_across_photons = []\n", + " recon_means_across_photons = []\n", + " for photon_idx, photon_count in enumerate(mean_photon_count_list):\n", + " color = color_for_photon_level(photon_count) \n", + " mi_value = mis_across_psfs[psf_idx][photon_idx] \n", + " recon_value = mses_across_psfs[psf_idx][photon_idx] \n", + " ax.scatter(mi_value, recon_value, color=color, marker=marker_for_psf(psf_name), s=50, zorder=100)\n", + " # add to lists to track later \n", + " mi_means_across_photons.append(mi_value)\n", + " recon_means_across_photons.append(recon_value)\n", + " #ax.errorbar(mis_across_psfs[psf_idx], mses_across_psfs[psf_idx], xerr=[mi_error_lower[psf_idx], mi_error_upper[psf_idx]], yerr=[mse_error_lower[psf_idx], mse_error_upper[psf_idx]], fmt='o', capsize=5, ecolor='black', markersize=8, barsabove=True)\n", + " \n", + " mi_means_across_photons = np.array(mi_means_across_photons)\n", + " recon_means_across_photons = np.array(recon_means_across_photons)\n", + " ax.plot(mi_means_across_photons, recon_means_across_photons, '--', color='gray', alpha=1, linewidth=2)\n", + "ax.set_xlabel(\"Mutual Information (bits per pixel)\")\n", + "ax.set_ylabel(\"Mean Squared Error\")\n", + "clear_spines(ax)\n", + "\n", + "\n", + "# legend\n", + "# ax.scatter([], [], color='k', marker='o', label='One Lens')\n", + "# ax.scatter([], [], color='k', marker='s', label='Four Lens')\n", + "# ax.scatter([], [], color='k', marker='*', label='Diffuser')\n", + "\n", + "ax.legend(loc='upper right', frameon=True)\n", + "ax.set_xlim([0, None])\n", + "\n", + "\n", + "\n", + "norm = mpl.colors.Normalize(vmin=min_log_photons, vmax=max_log_photons)\n", + "sm = plt.cm.ScalarMappable(cmap=colormap, norm=norm)\n", + "sm.set_array([])\n", + "cbar = plt.colorbar(sm, ax=ax, ticks=(np.log(valid_photon_counts)))\n", + "# set tick labels\n", + 
"cbar.ax.set_yticklabels(valid_photon_counts)\n", + "\n", + "\n", + "cbar.set_label('Photons per pixel')\n", + "\n", + "#plt.savefig('mse_vs_MI_with_confidence_intervals_log_photons.pdf', bbox_inches='tight', transparent=True)" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "fig, ax = plt.subplots(1, 1, figsize=(6, 5))\n", + "for psf_idx, psf_name in enumerate(psf_names):\n", + " # plot all of the points here. \n", + " mi_means_across_photons = []\n", + " recon_means_across_photons = []\n", + " for photon_idx, photon_count in enumerate(mean_photon_count_list):\n", + " color = color_for_photon_level(photon_count) \n", + " mi_value = mis_across_psfs[psf_idx][photon_idx] \n", + " recon_value = psnrs_across_psfs[psf_idx][photon_idx] \n", + " ax.scatter(mi_value, recon_value, color=color, marker=marker_for_psf(psf_name), s=50, zorder=100)\n", + " # add to lists to track later \n", + " mi_means_across_photons.append(mi_value)\n", + " recon_means_across_photons.append(recon_value)\n", + " #ax.errorbar(mis_across_psfs[psf_idx], psnrs_across_psfs[psf_idx], xerr=[mi_error_lower[psf_idx], mi_error_upper[psf_idx]], yerr=[psnr_error_lower[psf_idx], psnr_error_upper[psf_idx]], fmt='o', capsize=5, ecolor='black', markersize=8, barsabove=True)\n", + " \n", + " mi_means_across_photons = np.array(mi_means_across_photons)\n", + " recon_means_across_photons = np.array(recon_means_across_photons)\n", + " ax.plot(mi_means_across_photons, recon_means_across_photons, '--', color='gray', alpha=1, linewidth=2)\n", + "ax.set_xlabel(\"Mutual Information (bits per pixel)\")\n", + "ax.set_ylabel(\"Peak Signal-to-Noise Ratio (dB)\")\n", + "clear_spines(ax)\n", + "\n", + "\n", + "# legend\n", + "# ax.scatter([], [], color='k', marker='o', label='One Lens')\n", + "# ax.scatter([], [], color='k', marker='s', label='Four Lens')\n", + "# ax.scatter([], [], color='k', marker='*', label='Diffuser')\n", + "\n", + "ax.legend(loc='lower right', frameon=True)\n", + "ax.set_xlim([0, None])\n", + "\n", + "\n", + "\n", + "norm = mpl.colors.Normalize(vmin=min_log_photons, vmax=max_log_photons)\n", + "sm = plt.cm.ScalarMappable(cmap=colormap, norm=norm)\n", + "sm.set_array([])\n", + "cbar = plt.colorbar(sm, ax=ax, ticks=(np.log(valid_photon_counts)))\n", + "# set tick labels\n", + "cbar.ax.set_yticklabels(valid_photon_counts)\n", + "\n", + "\n", + "cbar.set_label('Photons per pixel')\n", + "\n", + "#plt.savefig('psnr_vs_MI_with_confidence_intervals_log_photons.pdf', bbox_inches='tight', transparent=True)\n", + " " + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "fig, ax = plt.subplots(1, 1, figsize=(6, 5))\n", + "for psf_idx, psf_name in enumerate(psf_names):\n", + " # plot all of the points here. 
\n", + " mi_means_across_photons = []\n", + " recon_means_across_photons = []\n", + " for photon_idx, photon_count in enumerate(mean_photon_count_list):\n", + " color = color_for_photon_level(photon_count) \n", + " mi_value = mis_across_psfs[psf_idx][photon_idx] \n", + " recon_value = ssims_across_psfs[psf_idx][photon_idx] \n", + " ax.scatter(mi_value, recon_value, color=color, marker=marker_for_psf(psf_name), s=50, zorder=100)\n", + " # add to lists to track later \n", + " mi_means_across_photons.append(mi_value)\n", + " recon_means_across_photons.append(recon_value)\n", + " #ax.errorbar(mis_across_psfs[psf_idx], ssims_across_psfs[psf_idx], xerr=[mi_error_lower[psf_idx], mi_error_upper[psf_idx]], yerr=[ssim_error_lower[psf_idx], ssim_error_upper[psf_idx]], fmt='o', capsize=5, ecolor='black', markersize=8, barsabove=True)\n", + " \n", + " mi_means_across_photons = np.array(mi_means_across_photons)\n", + " recon_means_across_photons = np.array(recon_means_across_photons)\n", + " ax.plot(mi_means_across_photons, recon_means_across_photons, '--', color='gray', alpha=1, linewidth=2)\n", + "ax.set_xlabel(\"Mutual Information (bits per pixel)\")\n", + "ax.set_ylabel(\"Structural Similarity Index Measure (SSIM)\")\n", + "clear_spines(ax)\n", + "\n", + "\n", + "# legend\n", + "# ax.scatter([], [], color='k', marker='o', label='One Lens')\n", + "# ax.scatter([], [], color='k', marker='s', label='Four Lens')\n", + "# ax.scatter([], [], color='k', marker='*', label='Diffuser')\n", + "\n", + "ax.legend(loc='lower right', frameon=True)\n", + "ax.set_xlim([0, None])\n", + "\n", + "\n", + "\n", + "norm = mpl.colors.Normalize(vmin=min_log_photons, vmax=max_log_photons)\n", + "sm = plt.cm.ScalarMappable(cmap=colormap, norm=norm)\n", + "sm.set_array([])\n", + "cbar = plt.colorbar(sm, ax=ax, ticks=(np.log(valid_photon_counts)))\n", + "# set tick labels\n", + "cbar.ax.set_yticklabels(valid_photon_counts)\n", + "\n", + "\n", + "cbar.set_label('Photons per pixel')\n", + "\n", + "#plt.savefig('ssim_vs_MI_with_confidence_intervals_log_photons.pdf', bbox_inches='tight', transparent=True)\n", + " " + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "Put all 3 into one figure" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "import glob\n", + "import numpy as np\n", + "import matplotlib.pyplot as plt\n", + "from cleanplots import *\n", + "from matplotlib.ticker import ScalarFormatter\n", + "\n", + "figs, axs = plt.subplots(1, 3, figsize=(12, 4), sharex=True)\n", + "\n", + "\n", + "for psf_idx, psf_name in enumerate(psf_names):\n", + " # plot all of the points here. 
\n", + " mi_means_across_photons = []\n", + " recon_means_across_photons = []\n", + " for photon_idx, photon_count in enumerate(mean_photon_count_list):\n", + " color = color_for_photon_level(photon_count) \n", + " mi_value = mis_across_psfs[psf_idx][photon_idx] \n", + " recon_value = mses_across_psfs[psf_idx][photon_idx] \n", + " axs[0].scatter(mi_value, recon_value, color=color, marker=marker_for_psf(psf_name), s=50, zorder=100)\n", + " # add to lists to track later \n", + " mi_means_across_photons.append(mi_value)\n", + " recon_means_across_photons.append(recon_value)\n", + " #ax.errorbar(mis_across_psfs[psf_idx], mses_across_psfs[psf_idx], xerr=[mi_error_lower[psf_idx], mi_error_upper[psf_idx]], yerr=[mse_error_lower[psf_idx], mse_error_upper[psf_idx]], fmt='o', capsize=5, ecolor='black', markersize=8, barsabove=True)\n", + " \n", + " mi_means_across_photons = np.array(mi_means_across_photons)\n", + " recon_means_across_photons = np.array(recon_means_across_photons)\n", + " axs[0].plot(mi_means_across_photons, recon_means_across_photons, '--', color='gray', alpha=1, linewidth=2)\n", + "#axs[0].set_xlabel(\"Mutual Information (bits per pixel)\")\n", + "axs[0].set_title(\"Mean Squared Error\")\n", + "clear_spines(axs[0])\n", + "\n", + "for psf_idx, psf_name in enumerate(psf_names):\n", + " # plot all of the points here. \n", + " mi_means_across_photons = []\n", + " recon_means_across_photons = []\n", + " for photon_idx, photon_count in enumerate(mean_photon_count_list):\n", + " color = color_for_photon_level(photon_count) \n", + " mi_value = mis_across_psfs[psf_idx][photon_idx] \n", + " recon_value = ssims_across_psfs[psf_idx][photon_idx] \n", + " axs[1].scatter(mi_value, recon_value, color=color, marker=marker_for_psf(psf_name), s=50, zorder=100)\n", + " # add to lists to track later \n", + " mi_means_across_photons.append(mi_value)\n", + " recon_means_across_photons.append(recon_value)\n", + " #ax.errorbar(mis_across_psfs[psf_idx], ssims_across_psfs[psf_idx], xerr=[mi_error_lower[psf_idx], mi_error_upper[psf_idx]], yerr=[ssim_error_lower[psf_idx], ssim_error_upper[psf_idx]], fmt='o', capsize=5, ecolor='black', markersize=8, barsabove=True)\n", + " \n", + " mi_means_across_photons = np.array(mi_means_across_photons)\n", + " recon_means_across_photons = np.array(recon_means_across_photons)\n", + " axs[1].plot(mi_means_across_photons, recon_means_across_photons, '--', color='gray', alpha=1, linewidth=2)\n", + "axs[1].set_xlabel(\"Mutual Information (bits per pixel)\")\n", + "axs[1].set_title(\"Structural Similarity Index Measure (SSIM)\")\n", + "clear_spines(axs[1])\n", + "\n", + "for psf_idx, psf_name in enumerate(psf_names):\n", + " # plot all of the points here. 
\n", + " mi_means_across_photons = []\n", + " recon_means_across_photons = []\n", + " for photon_idx, photon_count in enumerate(mean_photon_count_list):\n", + " color = color_for_photon_level(photon_count) \n", + " mi_value = mis_across_psfs[psf_idx][photon_idx] \n", + " recon_value = psnrs_across_psfs[psf_idx][photon_idx] \n", + " axs[2].scatter(mi_value, recon_value, color=color, marker=marker_for_psf(psf_name), s=50, zorder=100)\n", + " # add to lists to track later \n", + " mi_means_across_photons.append(mi_value)\n", + " recon_means_across_photons.append(recon_value)\n", + " #ax.errorbar(mis_across_psfs[psf_idx], psnrs_across_psfs[psf_idx], xerr=[mi_error_lower[psf_idx], mi_error_upper[psf_idx]], yerr=[psnr_error_lower[psf_idx], psnr_error_upper[psf_idx]], fmt='o', capsize=5, ecolor='black', markersize=8, barsabove=True)\n", + " \n", + " mi_means_across_photons = np.array(mi_means_across_photons)\n", + " recon_means_across_photons = np.array(recon_means_across_photons)\n", + " axs[2].plot(mi_means_across_photons, recon_means_across_photons, '--', color='gray', alpha=1, linewidth=2)\n", + "#axs[2].set_xlabel(\"Mutual Information (bits per pixel)\")\n", + "axs[2].set_title(\"Peak Signal-to-Noise Ratio (dB)\")\n", + "clear_spines(axs[2])\n", + "\n", + "# norm = mpl.colors.Normalize(vmin=min_log_photons, vmax=max_log_photons)\n", + "# sm = plt.cm.ScalarMappable(cmap=colormap, norm=norm)\n", + "# sm.set_array([])\n", + "# cbar = plt.colorbar(sm, ax=axs[2], ticks=(np.log(valid_photon_counts)))\n", + "# # set tick labels\n", + "# cbar.ax.set_yticklabels(valid_photon_counts)\n", + "\n", + "\n", + "# cbar.set_label('Photons per pixel')\n", + "\n", + "#plt.savefig(\"metrics_vs_MI_with_confidence_intervals_log_photons.pdf\", bbox_inches='tight', transparent=True)" + ] + } + ], + "metadata": { + "kernelspec": { + "display_name": "infotransformer", + "language": "python", + "name": "python3" + }, + "language_info": { + "codemirror_mode": { + "name": "ipython", + "version": 3 + }, + "file_extension": ".py", + "mimetype": "text/x-python", + "name": "python", + "nbconvert_exporter": "python", + "pygments_lexer": "ipython3", + "version": "3.10.14" + }, + "orig_nbformat": 4 + }, + "nbformat": 4, + "nbformat_minor": 2 +} diff --git a/lensless_imager/2024_10_23_pixelcnn_cifar10_updated_api_reruns_smaller_lr.py b/lensless_imager/2024_10_23_pixelcnn_cifar10_updated_api_reruns_smaller_lr.py new file mode 100644 index 0000000..bd9148a --- /dev/null +++ b/lensless_imager/2024_10_23_pixelcnn_cifar10_updated_api_reruns_smaller_lr.py @@ -0,0 +1,154 @@ +# --- +# jupyter: +# jupytext: +# text_representation: +# extension: .py +# format_name: percent +# format_version: '1.3' +# jupytext_version: 1.15.2 +# kernelspec: +# display_name: infotransformer +# language: python +# name: python3 +# --- + +# %% +# %load_ext autoreload +# %autoreload 2 + +# Final MI estimation script for lensless imager, used in paper. 
+ +import os +from jax import config +config.update("jax_enable_x64", True) +import sys +sys.path.append('/home/lakabuli/workspace/EncodingInformation/src') + +os.environ["CUDA_DEVICE_ORDER"] = "PCI_BUS_ID" +os.environ["CUDA_VISIBLE_DEVICES"] = '0' +from encoding_information.gpu_utils import limit_gpu_memory_growth +limit_gpu_memory_growth() + +from cleanplots import * +import jax.numpy as np +import numpy as onp +import tensorflow as tf +import tensorflow.keras as tfk + + +from lensless_helpers import * + +# %% +from encoding_information import extract_patches +from encoding_information.models import PixelCNN +from encoding_information.plot_utils import plot_samples +from encoding_information.models import PoissonNoiseModel +from encoding_information.image_utils import add_noise +from encoding_information import estimate_information + +# %% [markdown] +# ### Sweep Photon Count and Diffusers + +# %% +diffuser_psf = load_diffuser_32() +one_psf = load_single_lens_uniform(32) +two_psf = load_two_lens_uniform(32) +three_psf = load_three_lens_uniform(32) +four_psf = load_four_lens_uniform(32) +five_psf = load_five_lens_uniform(32) + +# %% +# set seed values for reproducibility +seed_values_full = np.arange(1, 5) + +# set photon properties +bias = 10 # in photons +mean_photon_count_list = [20, 40, 80, 160, 320] + +# set eligible psfs + +psf_patterns = [diffuser_psf, four_psf, one_psf] +psf_names = ['diffuser', 'four', 'one'] + +# MI estimator parameters +patch_size = 32 +num_patches = 10000 +val_set_size = 1000 +test_set_size = 1500 +num_samples = 8 +learning_rate = 1e-3 # using 5x iterations per epoch, using smaller lr, and using less patience since it should be a smoother curve. +num_iters_per_epoch = 500 +patience_val = 20 + + +save_dir = '/home/lakabuli/workspace/EncodingInformation/lensless_imager/mi_estimates_smaller_lr/' + + +# %% +for photon_count in mean_photon_count_list: + for index, psf_pattern in enumerate(psf_patterns): + val_loss_log = [] + mi_estimates = [] + lower_bounds = [] + upper_bounds = [] + for seed_value in seed_values_full: + # load dataset + (x_train, y_train), (x_test, y_test) = tfk.datasets.cifar10.load_data() + data = onp.concatenate((x_train, x_test), axis=0) + labels = np.concatenate((y_train, y_test), axis=0) + data = data.astype(np.float32) + # convert data to grayscale before converting to photons + if len(data.shape) == 4: + data = tf.image.rgb_to_grayscale(data).numpy() + data = data.squeeze() + # convert to photons with mean value of photon_count + data /= onp.mean(data) + data *= photon_count + # make tiled data + random_data, random_labels = generate_random_tiled_data(data, labels, seed_value) + + if psf_pattern is None: + start_idx = data.shape[-1] // 2 + end_idx = data.shape[-1] // 2 - 1 + psf_data = random_data[:, start_idx:-end_idx, start_idx:-end_idx] + else: + psf_data = convolved_dataset(psf_pattern, random_data) + # add small bias to data + psf_data += bias + # make patches for training and testing splits, random patching + patches = extract_patches(psf_data[:-test_set_size], patch_size=patch_size, num_patches=num_patches, seed=seed_value, verbose=True) + test_patches = extract_patches(psf_data[-test_set_size:], patch_size=patch_size, num_patches=test_set_size, seed=seed_value, verbose=True) + # put all the clean patches together for use in MI estimatino function later + full_clean_patches = onp.concatenate([patches, test_patches]) + # add noise to both sets + patches_noisy = add_noise(patches, seed=seed_value) + test_patches_noisy = 
add_noise(test_patches, seed=seed_value) + + # initialize pixelcnn + pixel_cnn = PixelCNN() + # fit pixelcnn to noisy patches. defaults to 10% val samples which will be 1k as desired. + # using smaller lr this time and adding seeding, letting it go for full training time. + val_loss_history = pixel_cnn.fit(patches_noisy, seed=seed_value, learning_rate=learning_rate, do_lr_decay=False, steps_per_epoch=num_iters_per_epoch, patience=patience_val) + # generate samples, not necessary for MI sweeps + # pixel_cnn_samples = pixel_cnn.generate_samples(num_samples=num_samples) + # # visualize samples + # plot_samples([pixel_cnn_samples], test_patches, model_names=['PixelCNN']) + + # instantiate noise model + noise_model = PoissonNoiseModel() + # estimate information using the fit pixelcnn and noise model, with clean data + pixel_cnn_info, pixel_cnn_lower_bound, pixel_cnn_upper_bound = estimate_information(pixel_cnn, noise_model, patches_noisy, + test_patches_noisy, clean_data=full_clean_patches, + confidence_interval=0.95) + print("PixelCNN estimated information: ", pixel_cnn_info) + print("PixelCNN lower bound: ", pixel_cnn_lower_bound) + print("PixelCNN upper bound: ", pixel_cnn_upper_bound) + # append results to lists + val_loss_log.append(val_loss_history) + mi_estimates.append(pixel_cnn_info) + lower_bounds.append(pixel_cnn_lower_bound) + upper_bounds.append(pixel_cnn_upper_bound) + np.save(save_dir + 'pixelcnn_mi_estimate_{}_photon_count_{}_psf.npy'.format(photon_count, psf_names[index]), np.array([mi_estimates, lower_bounds, upper_bounds])) + np.save(save_dir + 'pixelcnn_val_loss_{}_photon_count_{}_psf.npy'.format(photon_count, psf_names[index]), np.array(val_loss_log, dtype=object)) + np.save(save_dir + 'pixelcnn_mi_estimate_{}_photon_count_{}_psf.npy'.format(photon_count, psf_names[index]), np.array([mi_estimates, lower_bounds, upper_bounds])) + np.save(save_dir + 'pixelcnn_val_loss_{}_photon_count_{}_psf.npy'.format(photon_count, psf_names[index]), np.array(val_loss_log, dtype=object)) diff --git a/lensless_imager/lensless_helpers.py b/lensless_imager/lensless_helpers.py new file mode 100644 index 0000000..0693b28 --- /dev/null +++ b/lensless_imager/lensless_helpers.py @@ -0,0 +1,612 @@ +import numpy as np # use regular numpy for now, simpler +import scipy +from tqdm import tqdm +# import tensorflow as tf +# import tensorflow.keras as tfk +import gc +import warnings + +import skimage +import skimage.io +from skimage.transform import resize + +# from tensorflow.keras.optimizers import SGD + +def tile_9_images(data_set): + # takes 9 images and forms a tiled image + assert len(data_set) == 9 + return np.block([[data_set[0], data_set[1], data_set[2]],[data_set[3], data_set[4], data_set[5]],[data_set[6], data_set[7], data_set[8]]]) + +def generate_random_tiled_data(x_set, y_set, seed_value=-1): + # takes a set of images and labels and returns a set of tiled images and corresponding labels + # the size of the output should be 3x the size of the input + vert_shape = x_set.shape[1] * 3 + horiz_shape = x_set.shape[2] * 3 + random_data = np.zeros((x_set.shape[0], vert_shape, horiz_shape)) # for mnist this was 84 x 84 + random_labels = np.zeros((y_set.shape[0], 1)) + if seed_value==-1: + np.random.seed() + else: + np.random.seed(seed_value) + for i in range(x_set.shape[0]): + img_items = np.random.choice(x_set.shape[0], size=9, replace=True) + data_set = x_set[img_items] + random_labels[i] = y_set[img_items[4]] + random_data[i] = tile_9_images(data_set) + return random_data, random_labels + +def 
generate_repeated_tiled_data(x_set, y_set): + # takes set of images and labels and returns a set of repeated tiled images and corresponding labels, no randomness + # the size of the output is 3x the size of the input, this essentially is a wrapper for np.tile + repeated_data = np.tile(x_set, (1, 3, 3)) + repeated_labels = y_set # the labels are just what they were + return repeated_data, repeated_labels + +def convolved_dataset(psf, random_tiled_data): + # takes a psf and a set of tiled images and returns a set of convolved images, convolved image size is 2n + 1? same size as the random data when it's cropped + # tile size is two images worth plus one extra index value + vert_shape = psf.shape[0] * 2 + 1 + horiz_shape = psf.shape[1] * 2 + 1 + psf_dataset = np.zeros((random_tiled_data.shape[0], vert_shape, horiz_shape)) # 57 x 57 for the case of mnist 28x28 images, 65 x 65 for the cifar 32 x 32 images + for i in tqdm(range(random_tiled_data.shape[0])): + psf_dataset[i] = scipy.signal.fftconvolve(psf, random_tiled_data[i], mode='valid') + return psf_dataset + +def compute_entropy(eigenvalues): + sum_log_evs = np.sum(np.log2(eigenvalues)) + D = eigenvalues.shape[0] + gaussian_entropy = 0.5 * (sum_log_evs + D * np.log2(2 * np.pi * np.e)) + return gaussian_entropy + +def add_shot_noise(photon_scaled_images, photon_fraction=None, photons_per_pixel=None, assume_noiseless=True, seed_value=-1): + #adapted from henry, also uses a seed though + if seed_value==-1: + np.random.seed() + else: + np.random.seed(seed_value) + + # check all pixels greater than 0 + if np.any(photon_scaled_images < 0): + #warning about negative + warnings.warn(f"Negative pixel values detected. Clipping to 0.") + photon_scaled_images[photon_scaled_images < 0] = 0 + if photons_per_pixel is not None: + if photons_per_pixel > np.mean(photon_scaled_images): + warnings.warn(f"photons_per_pixel is greater than actual photon count ({photons_per_pixel}). Clipping to {np.mean(photon_scaled_images)}") + photons_per_pixel = np.mean(photon_scaled_images) + photon_fraction = photons_per_pixel / np.mean(photon_scaled_images) + + if photon_fraction > 1: + warnings.warn(f"photon_fraction is greater than 1 ({photon_fraction}). 
Clipping to 1.") + photon_fraction = 1 + + if assume_noiseless: + additional_sd = np.sqrt(photon_fraction * photon_scaled_images) + if np.any(np.isnan(additional_sd)): + warnings.warn('There are nans here') + additional_sd[np.isnan(additional_sd)] = 0 + # something here goes weird for RML + # + #else: + # additional_sd = np.sqrt(photon_fraction * photon_scaled_images) - photon_fraction * np.sqrt(photon_scaled_images) + simulated_images = photon_scaled_images * photon_fraction + additional_sd * np.random.randn(*photon_scaled_images.shape) + positive = np.array(simulated_images) + positive[positive < 0] = 0 # cant have negative counts + return np.array(positive) + +def tf_cast(data): + # normalizes data, loads it to a tensorflow array of type float32 + return tf.cast(data / np.max(data), tf.float32) +def tf_labels(labels): + # loads labels to a tensorflow array of type int64 + return tf.cast(labels, tf.int64) + + + +def run_model_simple(train_data, train_labels, test_data, test_labels, val_data, val_labels, seed_value=-1): + if seed_value == -1: + seed_val = np.random.randint(10, 1000) + tfk.utils.set_random_seed(seed_val) + else: + tfk.utils.set_random_seed(seed_value) + + model = tfk.models.Sequential() + model.add(tfk.layers.Flatten()) + model.add(tfk.layers.Dense(256, activation='relu')) + model.add(tfk.layers.Dense(256, activation='relu')) + model.add(tfk.layers.Dense(10, activation='softmax')) + + model.compile(optimizer='sgd', loss='sparse_categorical_crossentropy', metrics=['accuracy']) + + early_stop = tfk.callbacks.EarlyStopping(monitor="val_loss", # add in an early stopping option + mode="min", patience=5, + restore_best_weights=True, verbose=1) + history = model.fit(train_data, train_labels, validation_data=(val_data, val_labels), batch_size=32, epochs=50, callbacks=[early_stop]) + test_loss, test_acc = model.evaluate(test_data, test_labels) + return history, model, test_loss, test_acc + +def run_model_cnn(train_data, train_labels, test_data, test_labels, val_data, val_labels, seed_value=-1): + # structure from https://www.kaggle.com/code/cdeotte/how-to-choose-cnn-architecture-mnist + if seed_value == -1: + seed_val = np.random.randint(10, 1000) + tfk.utils.set_random_seed(seed_val) + else: + tfk.utils.set_random_seed(seed_value) + + model = tfk.models.Sequential() + model.add(tfk.layers.Conv2D(64, kernel_size=5, padding='same', activation='relu', input_shape=(57, 57, 1))) #64 and 128 works very slightly better + model.add(tfk.layers.MaxPool2D()) + model.add(tfk.layers.Conv2D(128, kernel_size=5, padding='same', activation='relu')) + model.add(tfk.layers.MaxPool2D()) + #model.add(tfk.layers.Conv2D(64, kernel_size=5, padding='same', activation='relu')) + #model.add(tfk.layers.MaxPool2D(padding='same')) + model.add(tfk.layers.Flatten()) + + #model.add(tfk.layers.Dense(256, activation='relu')) + model.add(tfk.layers.Dense(128, activation='relu')) + model.add(tfk.layers.Dense(10, activation='softmax')) + + model.compile(optimizer='sgd', loss='sparse_categorical_crossentropy', metrics=['accuracy']) + + early_stop = tfk.callbacks.EarlyStopping(monitor="val_loss", # add in an early stopping option + mode="min", patience=5, + restore_best_weights=True, verbose=1) + history = model.fit(train_data, train_labels, validation_data=(val_data, val_labels), epochs=50, batch_size=32, callbacks=[early_stop]) #validation data is not test data + test_loss, test_acc = model.evaluate(test_data, test_labels) + return history, model, test_loss, test_acc + +def seeded_permutation(seed_value, n): + # given 
fixed seed returns permutation order + np.random.seed(seed_value) + permutation_order = np.random.permutation(n) + return permutation_order + +def segmented_indices(permutation_order, n, training_fraction, test_fraction): + #given permutation order returns indices for each of the three sets + training_indices = permutation_order[:int(training_fraction*n)] + test_indices = permutation_order[int(training_fraction*n):int((training_fraction+test_fraction)*n)] + validation_indices = permutation_order[int((training_fraction+test_fraction)*n):] + return training_indices, test_indices, validation_indices + +def permute_data(data, labels, seed_value, training_fraction=0.8, test_fraction=0.1): + #validation fraction is implicit, if including a validation set, expect to use the remaining fraction of the data + permutation_order = seeded_permutation(seed_value, data.shape[0]) + training_indices, test_indices, validation_indices = segmented_indices(permutation_order, data.shape[0], training_fraction, test_fraction) + + training_data = data[training_indices] + training_labels = labels[training_indices] + testing_data = data[test_indices] + testing_labels = labels[test_indices] + validation_data = data[validation_indices] + validation_labels = labels[validation_indices] + + return (training_data, training_labels), (testing_data, testing_labels), (validation_data, validation_labels) + +def add_gaussian_noise(data, noise_level, seed_value=-1): + if seed_value==-1: + np.random.seed() + else: + np.random.seed(seed_value) + return data + noise_level * np.random.randn(*data.shape) + +def confidence_bars(data_array, noise_length, confidence_interval=0.95): + # can also use confidence interval 0.9 or 0.99 if want slightly different bounds + error_lo = np.percentile(data_array, 100 * (1 - confidence_interval) / 2, axis=1) + error_hi = np.percentile(data_array, 100 * (1 - (1 - confidence_interval) / 2), axis=1) + mean = np.mean(data_array, axis=1) + assert len(error_lo) == len(mean) == len(error_hi) == noise_length + return error_lo, error_hi, mean + + +######### This function is very outdated, don't use it!! used to be called test_system use the ones below instead +######### +def test_system_old(noise_level, psf_name, model_name, seed_values, data, labels, training_fraction, testing_fraction, diffuser_region, phlat_region, psf, noise_type, rml_region): + # runs the model for the number of seeds given, returns the test accuracy for each seed + test_accuracy_list = [] + for seed_value in seed_values: + seed_value = int(seed_value) + tfk.backend.clear_session() + gc.collect() + tfk.utils.set_random_seed(seed_value) # set random seed out here too? 
+ training, testing, validation = permute_data(data, labels, seed_value, training_fraction, testing_fraction) + x_train, y_train = training + x_test, y_test = testing + x_validation, y_validation = validation + + random_test_data, random_test_labels = generate_random_tiled_data(x_test, y_test, seed_value) + random_train_data, random_train_labels = generate_random_tiled_data(x_train, y_train, seed_value) + random_valid_data, random_valid_labels = generate_random_tiled_data(x_validation, y_validation, seed_value) + + if psf_name == 'uc': + test_data = random_test_data[:, 14:-13, 14:-13] + train_data = random_train_data[:, 14:-13, 14:-13] + valid_data = random_valid_data[:, 14:-13, 14:-13] + if psf_name == 'psf_4': + test_data = convolved_dataset(psf, random_test_data) + train_data = convolved_dataset(psf, random_train_data) + valid_data = convolved_dataset(psf, random_valid_data) + if psf_name == 'diffuser': + test_data = convolved_dataset(diffuser_region, random_test_data) + train_data = convolved_dataset(diffuser_region, random_train_data) + valid_data = convolved_dataset(diffuser_region, random_valid_data) + if psf_name == 'phlat': + test_data = convolved_dataset(phlat_region, random_test_data) + train_data = convolved_dataset(phlat_region, random_train_data) + valid_data = convolved_dataset(phlat_region, random_valid_data) + # 6/19/23 added RML option + if psf_name == 'rml': + test_data = convolved_dataset(rml_region, random_test_data) + train_data = convolved_dataset(rml_region, random_train_data) + valid_data = convolved_dataset(rml_region, random_valid_data) + + # address any tiny floating point negative values, which only occur in RML data + if np.any(test_data < 0): + #print('negative values in test data for {} psf'.format(psf_name)) + test_data[test_data < 0] = 0 + if np.any(train_data < 0): + #print('negative values in train data for {} psf'.format(psf_name)) + train_data[train_data < 0] = 0 + if np.any(valid_data < 0): + #print('negative values in valid data for {} psf'.format(psf_name)) + valid_data[valid_data < 0] = 0 + + + # additive gaussian noise, add noise after convolving, fixed 5/15/2023 + if noise_type == 'gaussian': + test_data = add_gaussian_noise(test_data, noise_level, seed_value) + train_data = add_gaussian_noise(train_data, noise_level, seed_value) + valid_data = add_gaussian_noise(valid_data, noise_level, seed_value) + if noise_type == 'poisson': + test_data = add_shot_noise(test_data, photons_per_pixel=noise_level, seed_value=seed_value, assume_noiseless=True) + train_data = add_shot_noise(train_data, photons_per_pixel=noise_level, seed_value=seed_value, assume_noiseless=True) + valid_data = add_shot_noise(valid_data, photons_per_pixel=noise_level, seed_value=seed_value, assume_noiseless=True) + + train_data, test_data, valid_data = tf_cast(train_data), tf_cast(test_data), tf_cast(valid_data) + random_train_labels, random_test_labels, random_valid_labels = tf_labels(random_train_labels), tf_labels(random_test_labels), tf_labels(random_valid_labels) + + if model_name == 'simple': + history, model, test_loss, test_acc = run_model_simple(train_data, random_train_labels, test_data, random_test_labels, valid_data, random_valid_labels, seed_value) + if model_name == 'cnn': + history, model, test_loss, test_acc = run_model_cnn(train_data, random_train_labels, test_data, random_test_labels, valid_data, random_valid_labels, seed_value) + test_accuracy_list.append(test_acc) + 
np.save('classification_results_rml_psf_619/test_accuracy_{}_noise_{}_{}_psf_{}_model.npy'.format(noise_level, noise_type, psf_name, model_name), test_accuracy_list) + + ###### CNN for 32x32 CIFAR10 images + # Originally written 11/14/2023, but then lost in a merge, recopied 1/14/2024 +def run_model_cnn_cifar(train_data, train_labels, test_data, test_labels, val_data, val_labels, seed_value=-1, max_epochs=50, patience=5): + # structure from https://www.kaggle.com/code/cdeotte/how-to-choose-cnn-architecture-mnist + # default architecture is 50 epochs and patience 5, but recently some need longer patience + if seed_value == -1: + seed_val = np.random.randint(10, 1000) + tfk.utils.set_random_seed(seed_val) + else: + tfk.utils.set_random_seed(seed_value) + model = tfk.models.Sequential() + model.add(tfk.layers.Conv2D(64, kernel_size=5, padding='same', activation='relu', input_shape=(65, 65, 1))) + model.add(tfk.layers.MaxPool2D()) + model.add(tfk.layers.Conv2D(128, kernel_size=5, padding='same', activation='relu')) + model.add(tfk.layers.MaxPool2D()) + #model.add(tfk.layers.Conv2D(64, kernel_size=5, padding='same', activation='relu')) + #model.add(tfk.layers.MaxPool2D(padding='same')) + model.add(tfk.layers.Flatten()) + + #model.add(tfk.layers.Dense(256, activation='relu')) + model.add(tfk.layers.Dense(128, activation='relu')) + model.add(tfk.layers.Dense(10, activation='softmax')) + + model.compile(optimizer='sgd', loss='sparse_categorical_crossentropy', metrics=['accuracy']) + + early_stop = tfk.callbacks.EarlyStopping(monitor="val_loss", # add in an early stopping option + mode="min", patience=patience, + restore_best_weights=True, verbose=1) + history = model.fit(train_data, train_labels, validation_data=(val_data, val_labels), epochs=max_epochs, batch_size=32, callbacks=[early_stop]) #validation data is not test data + test_loss, test_acc = model.evaluate(test_data, test_labels) + return history, model, test_loss, test_acc + +def make_ttv_sets(data, labels, seed_value, training_fraction, testing_fraction): + training, testing, validation = permute_data(data, labels, seed_value, training_fraction, testing_fraction) + training_data, training_labels = training + testing_data, testing_labels = testing + validation_data, validation_labels = validation + training_data, testing_data, validation_data = tf_cast(training_data), tf_cast(testing_data), tf_cast(validation_data) + training_labels, testing_labels, validation_labels = tf_labels(training_labels), tf_labels(testing_labels), tf_labels(validation_labels) + return (training_data, training_labels), (testing_data, testing_labels), (validation_data, validation_labels) + +def run_network_cifar(data, labels, seed_value, training_fraction, testing_fraction, mode='cnn', max_epochs=50, patience=5): + # small modification to be able to run 32x32 image data + training, testing, validation = make_ttv_sets(data, labels, seed_value, training_fraction, testing_fraction) + if mode == 'cnn': + history, model, test_loss, test_acc = run_model_cnn_cifar(training[0], training[1], + testing[0], testing[1], + validation[0], validation[1], seed_value, max_epochs, patience) + elif mode == 'simple': + history, model, test_loss, test_acc = run_model_simple(training[0], training[1], + testing[0], testing[1], + validation[0], validation[1], seed_value) + elif mode == 'new_cnn': + history, model, test_loss, test_acc = current_testing_model(training[0], training[1], + testing[0], testing[1], + validation[0], validation[1], seed_value, max_epochs, patience) + elif mode == 
'mom_cnn': + history, model, test_loss, test_acc = momentum_testing_model(training[0], training[1], + testing[0], testing[1], + validation[0], validation[1], seed_value, max_epochs, patience) + return history, model, test_loss, test_acc + + +def load_diffuser_psf(): + diffuser_psf = skimage.io.imread('psfs/diffuser_psf.png') + diffuser_psf = diffuser_psf[:,:,1] + diffuser_resize = diffuser_psf[200:500, 250:550] + diffuser_resize = resize(diffuser_resize, (100, 100), anti_aliasing=True) #resize(diffuser_psf, (28, 28)) + diffuser_region = diffuser_resize[:28, :28] + diffuser_region /= np.sum(diffuser_region) + return diffuser_region + +def load_phlat_psf(): + phlat_psf = skimage.io.imread('psfs/phlat_psf.png') + phlat_psf = phlat_psf[900:2900, 1500:3500, 1] + phlat_psf = resize(phlat_psf, (200, 200), anti_aliasing=True) + phlat_region = phlat_psf[10:38, 20:48] + phlat_region /= np.sum(phlat_region) + return phlat_region + +def load_4_psf(): + psf = np.zeros((28, 28)) + psf[20,20] = 1 + psf[15, 10] = 1 + psf[5, 13] = 1 + psf[23, 6] = 1 + psf = scipy.ndimage.gaussian_filter(psf, sigma=1) + psf /= np.sum(psf) + return psf + +# 6/9/23 added rml option +def load_rml_psf(): + rml_psf = skimage.io.imread('psfs/psf_8holes.png') + rml_psf = rml_psf[1000:3000, 1500:3500] + rml_psf_resize = resize(rml_psf, (100, 100), anti_aliasing=True) + rml_psf_region = rml_psf_resize[40:100, :60] + rml_psf_region = resize(rml_psf_region, (28, 28), anti_aliasing=True) + rml_psf_region /= np.sum(rml_psf_region) + return rml_psf_region + +def load_rml_new_psf(): + rml_psf = skimage.io.imread('psfs/psf_8holes.png') + rml_psf = rml_psf[1000:3000, 1500:3500] + rml_psf_small = resize(rml_psf, (85, 85), anti_aliasing=True) + rml_psf_region = rml_psf_small[52:80, 10:38] + rml_psf_region /= np.sum(rml_psf_region) + return rml_psf_region + +def load_single_lens(): + one_lens = np.zeros((28, 28)) + one_lens[14, 14] = 1 + one_lens = scipy.ndimage.gaussian_filter(one_lens, sigma=0.8) + one_lens /= np.sum(one_lens) + return one_lens + +def load_two_lens(): + two_lens = np.zeros((28, 28)) + two_lens[10, 10] = 1 + two_lens[20, 20] = 1 + two_lens = scipy.ndimage.gaussian_filter(two_lens, sigma=0.8) + two_lens /= np.sum(two_lens) + return two_lens + +def load_three_lens(): + three_lens = np.zeros((28, 28)) + three_lens[8, 12] = 1 + three_lens[16, 20] = 1 + three_lens[20, 7] = 1 + three_lens = scipy.ndimage.gaussian_filter(three_lens, sigma=0.8) + three_lens /= np.sum(three_lens) + return three_lens + + +def load_single_lens_32(): + one_lens = np.zeros((32, 32)) + one_lens[16, 16] = 1 + one_lens = scipy.ndimage.gaussian_filter(one_lens, sigma=0.8) + one_lens /= np.sum(one_lens) + return one_lens + +def load_two_lens_32(): + two_lens = np.zeros((32, 32)) + two_lens[10, 10] = 1 + two_lens[21, 21] = 1 + two_lens = scipy.ndimage.gaussian_filter(two_lens, sigma=0.8) + two_lens /= np.sum(two_lens) + return two_lens + +def load_three_lens_32(): + three_lens = np.zeros((32, 32)) + three_lens[9, 12] = 1 + three_lens[17, 22] = 1 + three_lens[24, 8] = 1 + three_lens = scipy.ndimage.gaussian_filter(three_lens, sigma=0.8) + three_lens /= np.sum(three_lens) + return three_lens + +def load_four_lens_32(): + psf = np.zeros((32, 32)) + psf[22, 22] = 1 + psf[15, 10] = 1 + psf[5, 12] = 1 + psf[28, 8] = 1 + psf = scipy.ndimage.gaussian_filter(psf, sigma=1) # note that this one is sigma 1, for mnist it's sigma 0.8 + psf /= np.sum(psf) + return psf + +def load_diffuser_32(): + diffuser_psf = skimage.io.imread('psfs/diffuser_psf.png') + diffuser_psf = 
diffuser_psf[:,:,1] + diffuser_resize = diffuser_psf[200:500, 250:550] + diffuser_resize = resize(diffuser_resize, (100, 100), anti_aliasing=True) #resize(diffuser_psf, (28, 28)) + diffuser_region = diffuser_resize[:32, :32] + diffuser_region /= np.sum(diffuser_region) + return diffuser_region + + + +### 10/15/2023: Make new versions of the model functions that train with Datasets - first attempt failed + +# lenses with centralized positions for use in task-specific estimations +def load_single_lens_uniform(size=32): + one_lens = np.zeros((size, size)) + one_lens[16, 16] = 1 + one_lens = scipy.ndimage.gaussian_filter(one_lens, sigma=0.8) + one_lens /= np.sum(one_lens) + return one_lens + +def load_two_lens_uniform(size=32): + two_lens = np.zeros((size, size)) + two_lens[16, 16] = 1 + two_lens[7, 9] = 1 + two_lens = scipy.ndimage.gaussian_filter(two_lens, sigma=0.8) + two_lens /= np.sum(two_lens) + return two_lens + +def load_three_lens_uniform(size=32): + three_lens = np.zeros((size, size)) + three_lens[16, 16] = 1 + three_lens[7, 9] = 1 + three_lens[23, 21] = 1 + three_lens = scipy.ndimage.gaussian_filter(three_lens, sigma=0.8) + three_lens /= np.sum(three_lens) + return three_lens + +def load_four_lens_uniform(size=32): + four_lens = np.zeros((size, size)) + four_lens[16, 16] = 1 + four_lens[7, 9] = 1 + four_lens[23, 21] = 1 + four_lens[8, 24] = 1 + four_lens = scipy.ndimage.gaussian_filter(four_lens, sigma=0.8) + four_lens /= np.sum(four_lens) + return four_lens +def load_five_lens_uniform(size=32): + five_lens = np.zeros((size, size)) + five_lens[16, 16] = 1 + five_lens[7, 9] = 1 + five_lens[23, 21] = 1 + five_lens[8, 24] = 1 + five_lens[21, 5] = 1 + five_lens = scipy.ndimage.gaussian_filter(five_lens, sigma=0.8) + five_lens /= np.sum(five_lens) + return five_lens + + + +## 01/24/2024 new CNN that's slightly deeper +def current_testing_model(train_data, train_labels, test_data, test_labels, val_data, val_labels, seed_value=-1, max_epochs=50, patience=20): + # structure from https://www.kaggle.com/code/amyjang/tensorflow-cifar10-cnn-tutorial + + if seed_value == -1: + seed_val = np.random.randint(10, 1000) + tfk.utils.set_random_seed(seed_val) + else: + tfk.utils.set_random_seed(seed_value) + + model = tf.keras.models.Sequential([ + tf.keras.layers.Conv2D(32, kernel_size=5, padding='same', input_shape=(65, 65, 1), activation='relu'), + tf.keras.layers.Conv2D(32, kernel_size=5, activation='relu'), + tf.keras.layers.MaxPooling2D(), + tf.keras.layers.Dropout(0.25), + + tf.keras.layers.Conv2D(64, kernel_size=3, padding='same', activation='relu'), + tf.keras.layers.Conv2D(64, kernel_size=3, activation='relu'), + tf.keras.layers.MaxPooling2D(), + tf.keras.layers.Dropout(0.25), + + tf.keras.layers.Flatten(), + tf.keras.layers.Dense(512, activation='relu'), + tf.keras.layers.Dropout(0.5), + tf.keras.layers.Dense(128, activation='relu'), + tf.keras.layers.Dense(10, activation='softmax'), + ]) + + model.compile(optimizer='sgd', loss='sparse_categorical_crossentropy', metrics=['accuracy']) + + early_stop = tfk.callbacks.EarlyStopping(monitor="val_loss", # add in an early stopping option + mode="min", patience=patience, + restore_best_weights=True, verbose=1) + print(model.optimizer.get_config()) + + history = model.fit(train_data, train_labels, validation_data=(val_data, val_labels), epochs=max_epochs, batch_size=32, callbacks=[early_stop]) #validation data is not test data + test_loss, test_acc = model.evaluate(test_data, test_labels) + return history, model, test_loss, test_acc + + + + +## 
01/24/2024 new CNN that's slightly deeper +def momentum_testing_model(train_data, train_labels, test_data, test_labels, val_data, val_labels, seed_value=-1, max_epochs=50, patience=20): + # structure from https://www.kaggle.com/code/amyjang/tensorflow-cifar10-cnn-tutorial + # includes nesterov momentum feature, rather than regular momentum + if seed_value == -1: + seed_val = np.random.randint(10, 1000) + tfk.utils.set_random_seed(seed_val) + else: + tfk.utils.set_random_seed(seed_value) + + model = tf.keras.models.Sequential([ + tf.keras.layers.Conv2D(32, kernel_size=5, padding='same', input_shape=(65, 65, 1), activation='relu'), + tf.keras.layers.Conv2D(32, kernel_size=5, activation='relu'), + tf.keras.layers.MaxPooling2D(), + tf.keras.layers.Dropout(0.25), + + tf.keras.layers.Conv2D(64, kernel_size=3, padding='same', activation='relu'), + tf.keras.layers.Conv2D(64, kernel_size=3, activation='relu'), + tf.keras.layers.MaxPooling2D(), + tf.keras.layers.Dropout(0.25), + + tf.keras.layers.Flatten(), + tf.keras.layers.Dense(512, activation='relu'), + tf.keras.layers.Dropout(0.5), + tf.keras.layers.Dense(128, activation='relu'), + tf.keras.layers.Dense(10, activation='softmax'), + ]) + + model.compile(optimizer=SGD(momentum=0.9, nesterov=True), loss='sparse_categorical_crossentropy', metrics=['accuracy']) + + early_stop = tfk.callbacks.EarlyStopping(monitor="val_loss", # add in an early stopping option + mode="min", patience=patience, + restore_best_weights=True, verbose=1) + + print(model.optimizer.get_config()) + + history = model.fit(train_data, train_labels, validation_data=(val_data, val_labels), epochs=max_epochs, batch_size=32, callbacks=[early_stop]) #validation data is not test data + test_loss, test_acc = model.evaluate(test_data, test_labels) + return history, model, test_loss, test_acc + + +# bootstrapping function +def compute_bootstraps(mses, psnrs, ssims, test_set_length, num_bootstraps=100): + bootstrap_mses = [] + bootstrap_psnrs = [] + bootstrap_ssims = [] + for bootstrap_idx in tqdm(range(num_bootstraps), desc='Bootstrapping to compute confidence interval'): + # select indices for sampling + bootstrap_indices = np.random.choice(test_set_length, test_set_length, replace=True) + # take the metric values at those indices + bootstrap_selected_mses = mses[bootstrap_indices] + bootstrap_selected_psnrs = psnrs[bootstrap_indices] + bootstrap_selected_ssims = ssims[bootstrap_indices] + # accumulate the mean of the selected metric values + bootstrap_mses.append(np.mean(bootstrap_selected_mses)) + bootstrap_psnrs.append(np.mean(bootstrap_selected_psnrs)) + bootstrap_ssims.append(np.mean(bootstrap_selected_ssims)) + bootstrap_mses = np.array(bootstrap_mses) + bootstrap_psnrs = np.array(bootstrap_psnrs) + bootstrap_ssims = np.array(bootstrap_ssims) + return bootstrap_mses, bootstrap_psnrs, bootstrap_ssims + +def compute_confidence_interval(list_of_items, confidence_interval=0.95): + # use this one, final version + assert confidence_interval > 0 and confidence_interval < 1 + mean_value = np.mean(list_of_items) + lower_bound = np.percentile(list_of_items, 50 * (1 - confidence_interval)) + upper_bound = np.percentile(list_of_items, 50 * (1 + confidence_interval)) + return mean_value, lower_bound, upper_bound +
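For reference, a minimal sketch of how the bootstrap helpers at the end of lensless_helpers.py are combined in the plotting notebook above. The metric file path and the photon count / PSF name in it are illustrative placeholders; each saved .npy from the deconvolution sweep holds per-image [MSE, PSNR, SSIM] arrays over the 1500-image test set.

import numpy as np
from lensless_helpers import compute_bootstraps, compute_confidence_interval

# Load one saved metric file from the Wiener deconvolution sweep (illustrative path);
# the array has shape (3, test_set_size): rows are MSE, PSNR, SSIM per test image.
metrics = np.load('deconvolutions/unsupervised_wiener_recon_20_photon_count_one_psf.npy')
mse_per_image, psnr_per_image, ssim_per_image = metrics[0], metrics[1], metrics[2]

# Resample the per-image metrics with replacement (100 bootstrap draws by default)
# and keep the mean of each draw.
boot_mse, boot_psnr, boot_ssim = compute_bootstraps(
    mse_per_image, psnr_per_image, ssim_per_image, test_set_length=len(mse_per_image))

# Reduce each bootstrap distribution to a mean and a 95% confidence interval;
# these are the values the notebook plots (the intervals end up smaller than the markers).
mean_mse, lower_mse, upper_mse = compute_confidence_interval(boot_mse, confidence_interval=0.95)
print(mean_mse, lower_mse, upper_mse)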