Source code for lensless.recon.gd

# #############################################################################
# gradient_descent.py
# =================
# Authors :
# Eric BEZZAM [ebezzam@gmail.com]
# Julien SAHLI [julien.sahli@epfl.ch]
# #############################################################################


import numpy as np
from lensless.recon.recon import ReconstructionAlgorithm
import inspect
from lensless.utils.io import load_data
import time

try:
    import torch

    torch_available = True
except ImportError:
    torch_available = False


class GradientDescentUpdate:
    """Gradient descent update techniques."""

    VANILLA = "vanilla"
    NESTEROV = "nesterov"
    FISTA = "fista"

    @staticmethod
    def all_values():
        vals = []
        for i in inspect.getmembers(GradientDescentUpdate):
            # remove private and protected functions, and this function
            if not i[0].startswith("_") and not callable(i[1]):
                vals.append(i[1])
        return vals


[docs]def non_neg(xi): """ Clip input so that it is non-negative. Parameters ---------- xi : :py:class:`~numpy.ndarray` Data to clip. Returns ------- nonneg : :py:class:`~numpy.ndarray` Non-negative projection of input. """ if torch_available and isinstance(xi, torch.Tensor): return torch.maximum(xi, torch.zeros_like(xi)) else: return np.maximum(xi, 0)
[docs]class GradientDescent(ReconstructionAlgorithm): """ Object for applying projected gradient descent. """
[docs] def __init__(self, psf, dtype=None, proj=non_neg, lip_fact=1.8, **kwargs): """ Parameters ---------- psf : :py:class:`~numpy.ndarray` or :py:class:`~torch.Tensor` Point spread function (PSF) that models forward propagation. Must be of shape (depth, height, width, channels) even if depth = 1 and channels = 1. You can use :py:func:`~lensless.io.load_psf` to load a PSF from a file such that it is in the correct format. dtype : float32 or float64 Data type to use for optimization. Default is float32. proj : :py:class:`function` Projection function to apply at each iteration. Default is non-negative. """ assert callable(proj) self._proj = proj self._lip_fact = lip_fact super(GradientDescent, self).__init__(psf, dtype, **kwargs) if self._denoiser is not None: print("Using denoiser in gradient descent.") # redefine projection function self._proj = self._denoiser
def reset(self): if self.is_torch: if self._initial_est is not None: self._image_est = self._initial_est else: # initial guess, half intensity image psf_flat = self._psf.reshape(-1, self._psf_shape[3]) pixel_start = ( torch.max(psf_flat, axis=0).values + torch.min(psf_flat, axis=0).values ) / 2 # initialize image estimate as [Batch, Depth, Height, Width, Channels] self._image_est = torch.ones_like(self._psf[None, ...]) * pixel_start # set step size as < 2 / lipschitz Hadj_flat = self._convolver._Hadj.reshape(-1, self._psf_shape[3]) H_flat = self._convolver._H.reshape(-1, self._psf_shape[3]) self._alpha = torch.real( self._lip_fact / torch.max(torch.abs(Hadj_flat * H_flat), axis=0).values ) else: if self._initial_est is not None: self._image_est = self._initial_est else: psf_flat = self._psf.reshape(-1, self._psf_shape[3]) pixel_start = (np.max(psf_flat, axis=0) + np.min(psf_flat, axis=0)) / 2 # initialize image estimate as [Batch, Depth, Height, Width, Channels] self._image_est = np.ones_like(self._psf[None, ...]) * pixel_start # set step size as < 2 / lipschitz Hadj_flat = self._convolver._Hadj.reshape(-1, self._psf_shape[3]) H_flat = self._convolver._H.reshape(-1, self._psf_shape[3]) self._alpha = np.real(self._lip_fact / np.max(Hadj_flat * H_flat, axis=0)) def _grad(self): diff = self._convolver.convolve(self._image_est) - self._data return self._convolver.deconvolve(diff) def _update(self, iter): self._image_est -= self._alpha * self._grad() self._image_est = self._form_image() def _form_image(self): if self._denoiser is not None: return self._proj(self._image_est, self._denoiser_noise_level) else: return self._proj(self._image_est)
[docs]class NesterovGradientDescent(GradientDescent): """ Object for applying projected gradient descent with Nesterov momentum for acceleration. A nice tutorial/ blog post on Nesterov momentum can be found `here <https://machinelearningmastery.com/gradient-descent-with-nesterov-momentum-from-scratch/>`_. """
[docs] def __init__(self, psf, dtype=None, proj=non_neg, p=0, mu=0.9, **kwargs): """ Parameters ---------- psf : :py:class:`~numpy.ndarray` or :py:class:`~torch.Tensor` Point spread function (PSF) that models forward propagation. Must be of shape (depth, height, width, channels) even if depth = 1 and channels = 1. You can use :py:func:`~lensless.io.load_psf` to load a PSF from a file such that it is in the correct format. dtype : float32 or float64 Data type to use for optimization. Default is float32. proj : :py:class:`function` Projection function to apply at each iteration. Default is non-negative. p : float Momentum parameter that keeps track of changes. By default, this is initialized to 0. mu : float Momentum parameter. Default is 0.9. """ self._p = p self._mu = mu super(NesterovGradientDescent, self).__init__(psf, dtype, proj, **kwargs)
def reset(self, p=0, mu=0.9): self._p = p self._mu = mu super(NesterovGradientDescent, self).reset() def _update(self, iter): p_prev = self._p self._p = self._mu * self._p - self._alpha * self._grad() self._image_est += -self._mu * p_prev + (1 + self._mu) * self._p # self._image_est = self._proj(self._image_est) self._image_est = self._form_image()
[docs]class FISTA(GradientDescent): """ Object for applying projected gradient descent with FISTA (Fast Iterative Shrinkage-Thresholding Algorithm) for acceleration. Paper: https://www.ceremade.dauphine.fr/~carlier/FISTA """
[docs] def __init__(self, psf, dtype=None, proj=non_neg, tk=1.0, **kwargs): """ Parameters ---------- psf : :py:class:`~numpy.ndarray` or :py:class:`~torch.Tensor` Point spread function (PSF) that models forward propagation. Must be of shape (depth, height, width, channels) even if depth = 1 and channels = 1. You can use :py:func:`~lensless.io.load_psf` to load a PSF from a file such that it is in the correct format. dtype : float32 or float64 Data type to use for optimization. Default is float32. proj : :py:class:`function` Projection function to apply at each iteration. Default is non-negative. tk : float Initial step size parameter for FISTA. It is updated at each iteration according to Eq. 4.2 of paper. By default, initialized to 1.0. """ self._initial_tk = tk super(FISTA, self).__init__(psf, dtype, proj, **kwargs) self._tk = tk self._xk = self._image_est
def reset(self, tk=None): super(FISTA, self).reset() if tk: self._tk = tk else: self._tk = self._initial_tk self._xk = self._image_est def _update(self, iter): self._image_est -= self._alpha * self._grad() xk = self._form_image() tk = (1 + np.sqrt(1 + 4 * self._tk**2)) / 2 self._image_est = xk + (self._tk - 1) / tk * (xk - self._xk) self._tk = tk self._xk = xk
def apply_gradient_descent(psf_fp, data_fp, n_iter, verbose=False, proj=non_neg, **kwargs): # load data psf, data = load_data(psf_fp=psf_fp, data_fp=data_fp, plot=False, **kwargs) # create reconstruction object recon = GradientDescent(psf, n_iter=n_iter, proj=proj) # set data recon.set_data(data) # perform reconstruction start_time = time.time() res = recon.apply(plot=False) proc_time = time.time() - start_time if verbose: print(f"Reconstruction time : {proc_time} s") print(f"Reconstruction shape: {res.shape}") return res