Skip to content
This repository has been archived by the owner on Jul 29, 2023. It is now read-only.

Inference figures overlays #155

Open
wants to merge 2 commits into
base: main
Choose a base branch
from
Open
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
75 changes: 61 additions & 14 deletions micro_dl/plotting/plot_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,8 @@
import matplotlib
matplotlib.use('Agg')
import matplotlib.pyplot as plt
from matplotlib.colors import LinearSegmentedColormap
from matplotlib.transforms import Bbox
import natsort
import numpy as np
import os
Expand All @@ -22,9 +24,10 @@ def save_predicted_images(input_imgs,
clip_limits=1,
font_size=15):
"""
Save plots of predicted images to prediction-figures directory.
Save plots of predicted images to prediction-figures directory:
- Overlay of target & prediction
- Plot of input, target, prediction, and overlay of target & prediction
- Figure containing plots of inputs, target, prediction, and overlay of target & prediction, inputs & target,
inputs & prediction, metric values; for every predicted channel a separated figure is created.

:param np.ndarray input_imgs: input images [c,y,x]
:param np.ndarray target_img: target [y,x]
Expand All @@ -40,13 +43,14 @@ def save_predicted_images(input_imgs,
os.makedirs(output_dir, exist_ok=True)

n_rows = 2
n_cols = np.shape(input_imgs)[0] + 1
n_cols = 2 * np.shape(input_imgs)[0] + 1
fig, ax = plt.subplots(n_rows, n_cols, squeeze=False)
ax = ax.flatten()
for axs in ax:
axs.axis('off')
fig.set_size_inches((12, 5 * n_rows))
fig.set_size_inches((12, 4 * n_rows))
axis_count = 0

# add input images to plot
for c, input_img in enumerate(input_imgs):
input_imgs[c] = hist_clipping(
Expand All @@ -60,22 +64,30 @@ def save_predicted_images(input_imgs,
ax[axis_count].set_title('Input', fontsize=font_size)
axis_count += 1

# add target image to plot
# clip values of target & predicted image
cur_target_chan = hist_clipping(
target_img,
clip_limits,
100 - clip_limits,
)
ax_target = ax[axis_count].imshow(cur_target_chan, cmap='gray')
cur_pred_chan = hist_clipping(
pred_img,
clip_limits,
100 - clip_limits,
)
max_intensity = np.max([np.amax(cur_target_chan), np.amax(cur_pred_chan)])

# add target image to plot
ax_target = ax[axis_count].imshow(cur_target_chan, cmap='gray', vmin=0, vmax=max_intensity)
ax[axis_count].axis('off')
divider = make_axes_locatable(ax[axis_count])
cax = divider.append_axes('right', size='5%', pad=0.05)
cbar = plt.colorbar(ax_target, cax=cax, orientation='vertical')
cbar = plt.colorbar(ax_target, cax=cax, orientation='vertical') # range shown in colorbar
ax[axis_count].set_title('Target', fontsize=font_size)
axis_count += 1

# add prediction to plot
ax_img = ax[axis_count].imshow(pred_img, cmap='gray')
ax_img = ax[axis_count].imshow(cur_pred_chan, cmap='gray', vmin=0, vmax=max_intensity)
ax[axis_count].axis('off')
divider = make_axes_locatable(ax[axis_count])
cax = divider.append_axes('right', size='5%', pad=0.05)
Expand All @@ -85,23 +97,58 @@ def save_predicted_images(input_imgs,

# add overlay target - prediction
cur_target_8bit = convert_to_8bit(cur_target_chan)
cur_prediction_8bit = convert_to_8bit(pred_img)
cur_prediction_8bit = convert_to_8bit(cur_pred_chan)
cur_target_pred = np.stack([cur_target_8bit, cur_prediction_8bit,
cur_target_8bit], axis=2)

ax[axis_count].imshow(cur_target_pred)
ax[axis_count].set_title('Overlay', fontsize=font_size)
ax[axis_count].set_title('Overlay Target-Prediction', fontsize=font_size)
axis_count += 1

# add overlay input-target
for input_img_idx, input_img in enumerate(input_imgs, 1):
ax[axis_count].imshow(input_img, 'gray')
cmap_pink = LinearSegmentedColormap.from_list("cmap_pink", [(1, 1, 1, 0), 'm'])
ax[axis_count].imshow(cur_target_chan, cmap_pink, alpha=1)
ax[axis_count].set_title('Overlay Input-Target', fontsize=font_size)
extent = full_extent(ax[axis_count]).transformed(fig.dpi_scale_trans.inverted())
# save input - target overlay
fig.savefig(os.path.join(output_dir, '{}_overlay_pink{}.{}'.format(output_fname, input_img_idx, ext)), bbox_inches=extent)
axis_count += 1

# add overlay input-prediction
for input_img_idx, input_img in enumerate(input_imgs, 1):
cmap_green = LinearSegmentedColormap.from_list("cmap_green", [(1, 1, 1, 0), 'g'])
ax[axis_count].imshow(input_img, 'gray')
ax[axis_count].imshow(cur_pred_chan, cmap_green, alpha=1)
ax[axis_count].set_title('Overlay Input-Prediction', fontsize=font_size)
extent = full_extent(ax[axis_count]).transformed(fig.dpi_scale_trans.inverted())
# save input - prediction overlay
fig.savefig(os.path.join(output_dir, '{}_overlay_green{}.{}'.format(output_fname, input_img_idx, ext)), bbox_inches=extent)
axis_count += 1

# add metrics
if metric is not None:
for c, (metric_name, value) in enumerate(zip(list(metric.keys()), metric.values[0][0:-1]), 1):
plt.figtext(0.5, 0.001+c*0.015, metric_name + ": {:.4f}".format(value), ha="center", fontsize=12)
plt.figtext(0.5, 0.001+c*0.017, metric_name + ": {:.4f}".format(value), ha="center", fontsize=12)

fname = os.path.join(output_dir, '{}.{}'.format(output_fname, ext))
fig.savefig(fname, dpi=300, bbox_inches='tight')
plt.close(fig)
fname = os.path.join(output_dir, '{}_overlay.{}'.format(output_fname, ext))
cv2.imwrite(fname, cur_target_pred)
# save target-prediction overlays as separate images
cv2.imwrite(os.path.join(output_dir, '{}_overlay_target_prediction.{}'.format(output_fname, ext)), cur_target_pred)


def full_extent(ax):
"""
Get the full extent of an axes.
:param ax: Matplotlib subplot axes
:return bbox: Matplotlib bounding box [[xmin, ymin], [xmax, ymax]]
"""
ax.figure.canvas.draw()
items = ax.get_xticklabels() + ax.get_yticklabels()
items += [ax, ax.title]
bbox = Bbox.union([item.get_window_extent() for item in items])
return bbox


def convert_to_8bit(img):
Expand Down