Source code for lensless.recon.unrolled_fista

# #############################################################################
# unrolled_fista.py
# =================
# Authors :
# Yohann PERRON [yohann.perron@gmail.com]
# #############################################################################

import numpy as np
from lensless.recon.trainable_recon import TrainableReconstructionAlgorithm
from lensless.recon.gd import non_neg

try:
    import torch

except ImportError:
    raise ImportError("Pytorch is require to use trainable reconstruction algorithms.")


class UnrolledFISTA(TrainableReconstructionAlgorithm):
    """
    Object for applying unrolled projected gradient descent with FISTA (Fast Iterative
    Shrinkage-Thresholding Algorithm) for acceleration.

    FISTA Paper: https://www.ceremade.dauphine.fr/~carlier/FISTA
    """

    def __init__(self, psf, n_iter=5, dtype=None, proj=non_neg, learn_tk=True, tk=1, **kwargs):
        """
        Constructor for unrolled FISTA algorithm.

        Parameters
        ----------
        psf : :py:class:`~torch.Tensor`
            Point spread function (PSF) that models forward propagation.
            Must be of shape (depth, height, width, channels) even if
            depth = 1 and channels = 1. You can use :py:func:`~lensless.io.load_psf`
            to load a PSF from a file such that it is in the correct format.
        n_iter : int, optional
            Number of iterations to unroll, by default 5.
        dtype : float32 or float64
            Data type to use for optimization.
        proj : :py:class:`function`, optional
            Projection function to apply at each iteration, by default non_neg.
        learn_tk : bool, optional
            Whether the tk parameters of FISTA should be learnt, by default True.
        tk : int or float, optional
            Initial value of tk, by default 1.
        """
        super(UnrolledFISTA, self).__init__(psf, n_iter=n_iter, dtype=dtype, reset=False, **kwargs)
        self._proj = proj

        n_channels = self._psf_shape[3]

        # Initial guess: every pixel is set, per channel, to the midpoint between the
        # PSF's minimum and maximum value ("half intensity image").
        # For an online approach one could instead reuse the last reconstruction.
        psf_flat = self._psf.reshape(-1, n_channels)
        pixel_start = (torch.max(psf_flat, dim=0).values + torch.min(psf_flat, dim=0).values) / 2
        self._image_init = torch.ones_like(self._psf[None, ...]) * pixel_start

        # Per-iteration, per-channel step size, initialised to 1.8 / max|Hadj * H|,
        # i.e. below the 2 / Lipschitz bound (assuming _Hadj is the adjoint/conjugate
        # of _H in the frequency domain — NOTE(review): confirm in the convolver).
        Hadj_flat = self._convolver._Hadj.reshape(-1, n_channels)
        H_flat = self._convolver._H.reshape(-1, n_channels)
        alpha_init = torch.ones(self._n_iter, n_channels).to(psf.device) * (
            1.8 / torch.max(torch.abs(Hadj_flat * H_flat), dim=0).values
        )
        if self.skip_unrolled:
            # Unrolled parameters are not trained: keep a plain tensor.
            self._alpha_p = alpha_init
        else:
            self._alpha_p = torch.nn.Parameter(alpha_init)

        # Standard FISTA momentum sequence: t_{k+1} = (1 + sqrt(1 + 4 t_k^2)) / 2.
        # n_iter + 1 values are needed because _update reads t[iter + 1].
        tk_values = [tk]
        for i in range(self._n_iter):
            tk_values.append((1 + np.sqrt(1 + 4 * tk_values[i] ** 2)) / 2)
        tk_init = torch.Tensor(tk_values).to(psf.device)
        if learn_tk and not self.skip_unrolled:
            # Move to the device BEFORE wrapping: calling `.to()` on a Parameter returns
            # a plain (non-leaf) tensor when the device changes, so it would not be
            # registered as a learnable parameter.
            self._tk_p = torch.nn.Parameter(tk_init)
        else:
            self._tk_p = tk_init

        self.reset()

    def _form_image(self):
        """Return the current estimate passed through the projection function."""
        return self._proj(self._image_est)

    def _grad(self):
        """
        Gradient of the data-fidelity term at the current estimate.

        Computes ``deconvolve(convolve(x) - data)``; presumably ``deconvolve`` applies
        the adjoint of the forward model, giving the gradient of 0.5 * ||Hx - y||^2.
        """
        diff = self._convolver.convolve(self._image_est) - self._data
        return self._convolver.deconvolve(diff)

    def reset(self, batch_size=1):
        """
        Reset the iterate and refresh the effective step sizes / momentum values.

        Parameters
        ----------
        batch_size : int, optional
            Batch size used to expand the default initial image, by default 1.
            Ignored when an explicit initial estimate (``self._initial_est``) is set.
        """
        if self._initial_est is not None:
            self._image_est = self._initial_est
        else:
            self._image_est = self._image_init.expand(batch_size, -1, -1, -1, -1)
        # Previous (projected) iterate used by the momentum term.
        self._xk = self._image_est

        # Enforce positivity of the (possibly learnt) step sizes and tk values.
        self._alpha = torch.abs(self._alpha_p)
        self._tk = torch.abs(self._tk_p)

    def _update(self, iter):
        """
        Perform one FISTA iteration.

        Parameters
        ----------
        iter : int
            Index of the current iteration (selects the step size and tk values).
        """
        # Gradient step followed by projection.
        self._image_est = self._image_est - self._alpha[iter] * self._grad()
        xk = self._proj(self._image_est)
        # Momentum (Nesterov-style) extrapolation.
        self._image_est = xk + (self._tk[iter] - 1) / self._tk[iter + 1] * (xk - self._xk)
        self._xk = xk