# Source code for lensless.utils.plot

# #############################################################################
# plot.py
# =================
# Authors :
# Eric BEZZAM [ebezzam@gmail.com]
# Julien SAHLI [julien.sahli@epfl.ch]
# #############################################################################


import numpy as np
import warnings
import matplotlib.pyplot as plt
import os
import json

from lensless.utils.image import FLOAT_DTYPES, get_max_val, gamma_correction, autocorr2d


def plot_image(img, ax=None, gamma=None, normalize=True):
    """
    Plot image data.

    Parameters
    ----------
    img : :py:class:`~numpy.ndarray`
        Data to plot. A leading axis of size 1 (depth) and a trailing axis of
        size 1 (color) are removed first. The remaining array may then be:
        2-D (grayscale image), 3-D with 3 entries in the last axis (RGB image),
        3-D otherwise (grayscale volume, depth first) or 4-D (RGB volume,
        depth first). Volumes are displayed as sum-projections along each of
        their first three axes on a 2x2 grid of axes.
    ax : :py:class:`~matplotlib.axes.Axes`, optional
        `Axes` object to fill for plotting/saving, default is to create one.
        For volumes, a 2x2 array of `Axes` is expected.
    gamma : float, optional
        Gamma correction factor to apply for plots. Only applied if greater
        than 1. Default is None.
    normalize : bool, optional
        Whether to normalize data to maximum range. If False, integer data is
        divided by the maximum value of its dtype (``get_max_val``) and float
        data is left unscaled. Default is True.

    Returns
    -------
    ax : :py:class:`~matplotlib.axes.Axes`
        Axes on which image is plot.
    """
    # if we have only 1 depth, remove the axis
    if img.shape[0] == 1:
        img = img[0]
    # if we have only 1 color channel, remove the axis
    if img.shape[-1] == 1:
        img = img[..., 0]

    is_3d = False
    if len(img.shape) == 4:
        # full 3D RGB format : [depth, width, height, color]
        disp_img = [np.sum(img, axis=axis) for axis in range(3)]
        cmap = None
        is_3d = True
    elif len(img.shape) == 3:
        # 3 axes: infer whether the missing one is depth or color from the shape
        if img.shape[2] == 3:
            # 2D rgb
            disp_img = [img]
            cmap = None
        else:
            # 3D grayscale
            disp_img = [np.sum(img, axis=axis) for axis in range(3)]
            cmap = "gray"
            is_3d = True
    elif len(img.shape) == 2:
        # 2D grayscale
        disp_img = [img]
        cmap = "gray"
    else:
        raise ValueError(f"Unexpected data shape : {img.shape}")

    if normalize:
        max_val = [d.max() for d in disp_img]
    else:
        max_val = [1 if d.dtype in FLOAT_DTYPES else get_max_val(d) for d in disp_img]

    assert len(disp_img) == 1 or len(disp_img) == 3

    # need float image for gamma correction and plotting
    img_norm = []
    for d, m in zip(disp_img, max_val):
        # an all-zero image would otherwise be divided by 0 and become all NaN
        normed = d / (m if m != 0 else 1)
        if gamma and gamma > 1:
            normed = gamma_correction(normed, gamma=gamma)
        img_norm.append(normed)

    if ax is None:
        if is_3d:
            _, ax = plt.subplots(2, 2)
        else:
            _, ax = plt.subplots()
    elif is_3d:
        assert len(ax) == 2
        assert len(ax[0]) == 2

    if len(img_norm) == 1:
        ax.imshow(img_norm[0], cmap=cmap)
    else:
        # plot each projection separately
        ax[0, 0].imshow(img_norm[0], cmap=cmap)
        ax[0, 1].imshow(np.swapaxes(img_norm[2], 0, 1), cmap=cmap)
        ax[1, 0].imshow(img_norm[1], cmap=cmap)
        ax[1, 1].axis("off")
        ax[0, 1].set_xlabel("Depth")
        ax[1, 0].set_ylabel("Depth")

    return ax
def pixel_histogram(img, nbits=None, ax=None, log_scale=True):
    """
    Plot pixel value histogram.

    Parameters
    ----------
    img : :py:class:`~numpy.ndarray`
        2D image, or 3D image whose last axis holds the color channels
        (plotted as red, green, blue).
    nbits : int, optional
        Bit-depth of camera data. If provided, the histogram covers the values
        ``0`` to ``2**nbits - 1``; otherwise it covers ``0`` to ``int(img.max())``.
    ax : :py:class:`~matplotlib.axes.Axes`, optional
        `Axes` object to fill, default is to create one.
    log_scale : bool, optional
        Whether to use log scale in counting number of pixels.

    Returns
    -------
    ax : :py:class:`~matplotlib.axes.Axes`
        Axes on which histogram is plot.
    """
    if ax is None:
        _, ax = plt.subplots()

    if nbits:
        max_val = 2**nbits - 1
    else:
        max_val = int(img.max())

    # One unit-width bin [k, k+1) per integer value 0..max_val, so the x-index of
    # the plotted histogram equals the pixel value. Using only ``max_val`` bins
    # over this range would merge neighbouring values, and would fail with
    # ``bins=0`` when the image maximum is below 1.
    n_bins = max_val + 1
    hist_range = (0, max_val + 1)

    if len(img.shape) == 3:
        # 3D image
        for i, col in enumerate(("r", "g", "b")):
            hist, _ = np.histogram(img[:, :, i].ravel(), bins=n_bins, range=hist_range)
            ax.plot(hist, color=col)
    else:
        # 2D image
        hist, _ = np.histogram(img.ravel(), bins=n_bins, range=hist_range)
        ax.plot(hist, color="gray")

    # 10% margin on both sides of the value range
    ax.set_xlim([-0.1 * max_val, 1.1 * max_val])
    if log_scale:
        ax.set_yscale("log")
    ax.set_xlabel("Pixel value")
    ax.grid()
    return ax
def plot_cross_section(
    vals,
    idx=None,
    ax=None,
    dB=True,
    plot_db_drop=3,
    min_val=1e-4,
    max_val=None,
    plot_width=None,
    **kwargs,
):
    """
    Plot cross-section of a 2-D image.

    Parameters
    ----------
    vals : :py:class:`~numpy.ndarray`
        2-D image data.
    idx : int, optional
        Row for which to plot cross-section. Default is the row containing the
        maximum value of ``vals``.
    ax : :py:class:`~matplotlib.axes.Axes`, optional
        `Axes` object to fill, default is to create one.
    dB : bool, optional
        Whether to plot in dB scale (``10 * log10``). Default is True.
    plot_db_drop : float, optional
        In dB mode, mark (and print) the width at which the cross-section has
        dropped by this many dB from its peak. A falsy value disables it.
        Default is 3.
    min_val : float, optional
        Floor for the normalized cross-section. The effective floor is
        ``max(min_val, cross_section.min())``; in dB mode smaller values are
        clipped to it. It is also the lower y-limit. Default is 1e-4.
    max_val : float, optional
        Value by which the cross-section is divided. Default is its maximum.
    plot_width : int, optional
        If provided, limit the x-axis to roughly ``plot_width`` samples
        centered on the peak.
    **kwargs
        Passed to :py:meth:`~matplotlib.axes.Axes.plot`.

    Returns
    -------
    ax : :py:class:`~matplotlib.axes.Axes`
        Axes on which cross-section is plot.
    cross_section : :py:class:`~numpy.ndarray`
        The plotted cross-section (normalized, and in dB if ``dB`` is True).
    """
    if ax is None:
        _, ax = plt.subplots()

    # get cross-section
    if idx is None:
        # if no index, take the row containing the maximum value
        idx = np.unravel_index(np.argmax(vals, axis=None), vals.shape)[0]
    cross_section = vals[idx, :].astype(np.float32)

    # normalize
    if max_val is None:
        max_val = cross_section.max()
    cross_section /= max_val
    min_val = max(min_val, cross_section.min())
    if dB:
        cross_section[cross_section < min_val] = min_val
        cross_section = 10 * np.log10(cross_section)
        min_val = 10 * np.log10(min_val)
        ax.set_ylabel("dB")

    # x-axis in samples relative to the peak
    peak = np.argmax(cross_section)
    x_vals = np.arange(len(cross_section)) - peak
    ax.plot(x_vals, cross_section, **kwargs)

    # The normalized peak sits at 0 dB in dB mode and at 1 in linear mode. An
    # upper limit of 0 in linear mode would invert the axis and hide the curve.
    upper = 0.0 if dB else 1.0
    ax.set_ylim([min_val, max(upper, float(np.max(cross_section)))])
    if plot_width is not None:
        half_width = plot_width // 2 + 1
        ax.set_xlim([-half_width, half_width])
    ax.grid()
    ax.set_title("Cross-section")

    if dB and plot_db_drop:
        # measure relative to the peak on a copy, so the returned array is the plotted one
        rel_db = cross_section - np.max(cross_section)
        zero_crossings = np.where(np.diff(np.signbit(rel_db + plot_db_drop)))[0]
        if len(zero_crossings) >= 2:
            zero_crossings -= peak
            # crossing closest to the peak; the width assumes a symmetric profile
            first_crossing = np.abs(zero_crossings[np.argmin(np.abs(zero_crossings))])
            width = 2 * first_crossing
            ax.axvline(x=-first_crossing, c="k", linestyle="--")
            ax.axvline(x=+first_crossing, c="k", linestyle="--")
            print(f"-{plot_db_drop}dB width = {width} pixels")
            ax.set_xlabel(f"-{plot_db_drop}dB width = {width}")
        else:
            warnings.warn(
                "Width could not be determined. Did not detect two -{} points : {}".format(
                    plot_db_drop, zero_crossings
                )
            )

    return ax, cross_section
def plot_autocorr2d(vals, pad_mode="reflect", ax=None):
    """
    Plot 2-D autocorrelation of image.

    Parameters
    ----------
    vals : :py:class:`~numpy.ndarray`
        2-D image.
    pad_mode : str
        Desired padding. See NumPy documentation:
        https://numpy.org/doc/stable/reference/generated/numpy.pad.html
    ax : :py:class:`~matplotlib.axes.Axes`, optional
        `Axes` object to fill, default is to create one.

    Returns
    -------
    ax : :py:class:`~matplotlib.axes.Axes`
        Axes on which auto-correlation is plot.
    autocorr : :py:class:`~numpy.ndarray`
        Auto-correlation (unscaled).
    """
    nbit_plot = 8
    max_val_plot = 2**nbit_plot - 1

    # compute autocorrelation
    autocorr = autocorr2d(vals, pad_mode=pad_mode)

    # rescale to [0, 1] for plotting
    data = autocorr - np.min(autocorr)
    peak = np.max(np.abs(data))
    if peak > 0:
        # a constant autocorrelation would otherwise give 0/0 = NaN
        data = data / peak

    # scale by bit depth and convert to an 8-bit image for display
    autocorr_img = (max_val_plot * data).astype(np.uint8)

    if ax is None:
        _, ax = plt.subplots()
    ax.imshow(autocorr_img, cmap="gray", vmin=0, vmax=max_val_plot)
    ax.axis("off")
    return ax, autocorr
def plot_autocorr_rgb(img, width=3, figsize=None, plot_psf=False, psf_gamma=2.2, verbose=False):
    """
    Plot autocorrelation of each channel of an image, together with a
    horizontal cross-section through each autocorrelation peak.

    Parameters
    ----------
    img : :py:class:`~numpy.ndarray`
        3-D image whose last axis holds 3 color channels (plotted as Red,
        Green, Blue).
    width : float, optional
        dB drop at which the width of each cross-section is measured (passed
        as ``plot_db_drop`` to :py:func:`plot_cross_section`). Default is 3dB.
    figsize : tuple, optional
        Figure size passed to :py:func:`matplotlib.pyplot.subplots`.
    plot_psf : bool, optional
        If True, add a first row showing each channel of ``img``.
    psf_gamma : float, optional
        Gamma correction used when showing the channels of ``img``
        (only if ``plot_psf`` is True). Default is 2.2.
    verbose : bool, optional
        Whether to print the range (max - min) of each cross-section.

    Returns
    -------
    ax : :py:class:`~matplotlib.axes.Axes`
        Axes of the last (Blue) cross-section plot.
    """
    assert len(img.shape) == 3, "Image must be 3D"
    assert img.shape[2] == 3, "Image must have 3 color channels"

    # row layout: [channel row (only if plot_psf)], autocorrelation row, cross-section row
    auto_row = 1 if plot_psf else 0
    cross_row = auto_row + 1
    _, ax_auto = plt.subplots(
        ncols=3, nrows=cross_row + 1, num="Autocorrelations", figsize=figsize
    )

    for i, c in enumerate(["Red", "Green", "Blue"]):
        if plot_psf:
            plot_image(img[:, :, i], ax=ax_auto[0][i], gamma=psf_gamma, normalize=True)
            ax_auto[0][i].axis("off")
            ax_auto[0][i].set_title(f"{c} PSF")

        # plot autocorrelation; the title goes on the same row as the image
        _, autocorr_c = plot_autocorr2d(img[:, :, i], ax=ax_auto[auto_row][i])
        ax_auto[auto_row][i].set_title("Autocorrelation")

        # horizontal cross-section through the autocorrelation peak
        idx = np.unravel_index(np.argmax(autocorr_c, axis=None), autocorr_c.shape)[0]
        ax, cross_section = plot_cross_section(
            autocorr_c,
            idx=idx,
            color=c,
            ax=ax_auto[cross_row][i],
            plot_db_drop=width,
        )
        if verbose:
            print(f"Maximum drop in {c} channel: {cross_section.max() - cross_section.min()}")
        if i != 0:
            # only the first column keeps the y-label
            ax.set_ylabel("")

    return ax
def compare_models(model_paths, max_epoch=None, linewidth=2, fontsize=18, metrics=None):
    """
    Plot train and test loss for multiple models, and print metrics for best epoch.

    Each model directory must contain a ``metrics.json`` file with the keys
    ``"LOSS"`` (train loss per epoch), ``"MSE"`` and ``"LPIPS_Vgg"`` (summed
    to form the test loss per epoch), ``"best_epoch"``, and every entry of
    ``metrics``.

    Parameters
    ----------
    model_paths : dict
        Dictionary of model names and their paths.
    max_epoch : int, optional
        Maximum epoch to plot. Default is None (plot all epochs). The printed
        best-epoch report always uses the full history.
    linewidth : int, optional
        Line width for plot. Default is 2.
    fontsize : int, optional
        Font size for plot. Default is 18.
    metrics : list, optional
        List of metrics to print. Default is ["PSNR", "SSIM", "LPIPS_Vgg"].

    Returns
    -------
    ax : :py:class:`~matplotlib.axes.Axes`
        Axes with the train (dashed) and test (solid) loss curves.
    """
    import matplotlib.colors as mcolors

    if metrics is None:
        metrics = ["PSNR", "SSIM", "LPIPS_Vgg"]

    # colors are taken from the end of the Tableau palette and cycle when there
    # are more models than colors
    plot_colors = list(reversed(list(mcolors.TABLEAU_COLORS.keys())))

    # plot train and test loss
    _, ax = plt.subplots()
    for i, (model, model_path) in enumerate(model_paths.items()):
        _metrics_path = os.path.join(model_path, "metrics.json")
        assert os.path.exists(_metrics_path), f"Path {_metrics_path} does not exist"
        with open(_metrics_path, encoding="utf-8") as f:
            _test_metrics = json.load(f)
        color = plot_colors[i % len(plot_colors)]

        train_loss = np.array(_test_metrics["LOSS"])
        test_loss = np.array(_test_metrics["MSE"]) + np.array(_test_metrics["LPIPS_Vgg"])

        # truncate only the plotted curves, so indexing by best_epoch below
        # cannot go out of bounds when best_epoch > max_epoch
        train_plot, test_plot = train_loss, test_loss
        if max_epoch is not None:
            train_plot = train_loss[: max_epoch + 1]
            test_plot = test_loss[: max_epoch + 1]
        ax.plot(
            train_plot, label=model + " (train)", color=color, linestyle="--", linewidth=linewidth
        )
        ax.plot(test_plot, label=model + " (test)", linestyle="-", color=color, linewidth=linewidth)

        best_epoch = _test_metrics["best_epoch"]
        print(f"\n-- {model} --")
        print(f"Best epoch for {model}: {best_epoch} / {len(test_loss)-1}")
        print(f"Best test loss for {model}: {test_loss[best_epoch]}")
        # print metrics
        for _metric in metrics:
            print(f"{_metric}: {np.array(_test_metrics[_metric])[best_epoch]:.3}")

    # set font size
    ax.tick_params(axis="both", which="major", labelsize=fontsize)
    ax.set_xlabel("Epoch", fontsize=fontsize)
    ax.set_title("Train-test loss", fontsize=fontsize)
    # legend inside the axes, upper-right corner
    ax.legend(loc="upper right", fontsize=fontsize)
    if max_epoch is not None:
        ax.set_xlim([0, max_epoch])
    ax.grid()
    return ax