diff --git a/micro_dl/plotting/plot_utils.py b/micro_dl/plotting/plot_utils.py index a690d309..68d2c4c1 100644 --- a/micro_dl/plotting/plot_utils.py +++ b/micro_dl/plotting/plot_utils.py @@ -4,6 +4,8 @@ import matplotlib matplotlib.use('Agg') import matplotlib.pyplot as plt +from matplotlib.colors import LinearSegmentedColormap +from matplotlib.transforms import Bbox import natsort import numpy as np import os @@ -22,9 +24,10 @@ def save_predicted_images(input_imgs, clip_limits=1, font_size=15): """ - Save plots of predicted images to prediction-figures directory. + Save plots of predicted images to prediction-figures directory: - Overlay of target & prediction - - Plot of input, target, prediction, and overlay of target & prediction + - Figure containing plots of inputs, target, prediction, and overlay of target & prediction, inputs & target, + inputs & prediction, metric values; for every predicted channel a separated figure is created. :param np.ndarray input_imgs: input images [c,y,x] :param np.ndarray target_img: target [y,x] @@ -40,13 +43,14 @@ def save_predicted_images(input_imgs, os.makedirs(output_dir, exist_ok=True) n_rows = 2 - n_cols = np.shape(input_imgs)[0] + 1 + n_cols = 2 * np.shape(input_imgs)[0] + 1 fig, ax = plt.subplots(n_rows, n_cols, squeeze=False) ax = ax.flatten() for axs in ax: axs.axis('off') - fig.set_size_inches((12, 5 * n_rows)) + fig.set_size_inches((12, 4 * n_rows)) axis_count = 0 + # add input images to plot for c, input_img in enumerate(input_imgs): input_imgs[c] = hist_clipping( @@ -60,22 +64,30 @@ def save_predicted_images(input_imgs, ax[axis_count].set_title('Input', fontsize=font_size) axis_count += 1 - # add target image to plot + # clip values of target & predicted image cur_target_chan = hist_clipping( target_img, clip_limits, 100 - clip_limits, ) - ax_target = ax[axis_count].imshow(cur_target_chan, cmap='gray') + cur_pred_chan = hist_clipping( + pred_img, + clip_limits, + 100 - clip_limits, + ) + max_intensity = np.max([np.amax(cur_target_chan), np.amax(cur_pred_chan)]) + + # add target image to plot + ax_target = ax[axis_count].imshow(cur_target_chan, cmap='gray', vmin=0, vmax=max_intensity) ax[axis_count].axis('off') divider = make_axes_locatable(ax[axis_count]) cax = divider.append_axes('right', size='5%', pad=0.05) - cbar = plt.colorbar(ax_target, cax=cax, orientation='vertical') + cbar = plt.colorbar(ax_target, cax=cax, orientation='vertical') # range shown in colorbar ax[axis_count].set_title('Target', fontsize=font_size) axis_count += 1 # add prediction to plot - ax_img = ax[axis_count].imshow(pred_img, cmap='gray') + ax_img = ax[axis_count].imshow(cur_pred_chan, cmap='gray', vmin=0, vmax=max_intensity) ax[axis_count].axis('off') divider = make_axes_locatable(ax[axis_count]) cax = divider.append_axes('right', size='5%', pad=0.05) @@ -85,23 +97,58 @@ def save_predicted_images(input_imgs, # add overlay target - prediction cur_target_8bit = convert_to_8bit(cur_target_chan) - cur_prediction_8bit = convert_to_8bit(pred_img) + cur_prediction_8bit = convert_to_8bit(cur_pred_chan) cur_target_pred = np.stack([cur_target_8bit, cur_prediction_8bit, cur_target_8bit], axis=2) - ax[axis_count].imshow(cur_target_pred) - ax[axis_count].set_title('Overlay', fontsize=font_size) + ax[axis_count].set_title('Overlay Target-Prediction', fontsize=font_size) axis_count += 1 + + # add overlay input-target + for input_img_idx, input_img in enumerate(input_imgs, 1): + ax[axis_count].imshow(input_img, 'gray') + cmap_pink = LinearSegmentedColormap.from_list("cmap_pink", [(1, 1, 1, 0), 'm']) + ax[axis_count].imshow(cur_target_chan, cmap_pink, alpha=1) + ax[axis_count].set_title('Overlay Input-Target', fontsize=font_size) + extent = full_extent(ax[axis_count]).transformed(fig.dpi_scale_trans.inverted()) + # save input - target overlay + fig.savefig(os.path.join(output_dir, '{}_overlay_pink{}.{}'.format(output_fname, input_img_idx, ext)), bbox_inches=extent) + axis_count += 1 + + # add overlay input-prediction + for input_img_idx, input_img in enumerate(input_imgs, 1): + cmap_green = LinearSegmentedColormap.from_list("cmap_green", [(1, 1, 1, 0), 'g']) + ax[axis_count].imshow(input_img, 'gray') + ax[axis_count].imshow(cur_pred_chan, cmap_green, alpha=1) + ax[axis_count].set_title('Overlay Input-Prediction', fontsize=font_size) + extent = full_extent(ax[axis_count]).transformed(fig.dpi_scale_trans.inverted()) + # save input - prediction overlay + fig.savefig(os.path.join(output_dir, '{}_overlay_green{}.{}'.format(output_fname, input_img_idx, ext)), bbox_inches=extent) + axis_count += 1 + # add metrics if metric is not None: for c, (metric_name, value) in enumerate(zip(list(metric.keys()), metric.values[0][0:-1]), 1): - plt.figtext(0.5, 0.001+c*0.015, metric_name + ": {:.4f}".format(value), ha="center", fontsize=12) + plt.figtext(0.5, 0.001+c*0.017, metric_name + ": {:.4f}".format(value), ha="center", fontsize=12) fname = os.path.join(output_dir, '{}.{}'.format(output_fname, ext)) fig.savefig(fname, dpi=300, bbox_inches='tight') plt.close(fig) - fname = os.path.join(output_dir, '{}_overlay.{}'.format(output_fname, ext)) - cv2.imwrite(fname, cur_target_pred) + # save target-prediction overlays as separate images + cv2.imwrite(os.path.join(output_dir, '{}_overlay_target_prediction.{}'.format(output_fname, ext)), cur_target_pred) + + +def full_extent(ax): + """ + Get the full extent of an axes. + :param ax: Matplotlib subplot axes + :return bbox: Matplotlib bounding box [[xmin, ymin], [xmax, ymax]] + """ + ax.figure.canvas.draw() + items = ax.get_xticklabels() + ax.get_yticklabels() + items += [ax, ax.title] + bbox = Bbox.union([item.get_window_extent() for item in items]) + return bbox def convert_to_8bit(img):