Source code for lensless.recon.unrolled_admm

# #############################################################################
# unrolled_admm.py
# =================
# Authors :
# Yohann PERRON [yohann.perron@gmail.com]
# Eric BEZZAM [ebezzam@gmail.com]
# #############################################################################

from lensless.recon.trainable_recon import TrainableReconstructionAlgorithm
from lensless.recon.admm import soft_thresh, finite_diff, finite_diff_adj, finite_diff_gram


try:
    import torch

except ImportError as err:
    # Chain the original error so the underlying import failure stays visible.
    raise ImportError("Pytorch is required to use trainable reconstruction algorithms.") from err


class UnrolledADMM(TrainableReconstructionAlgorithm):
    """
    Object for applying unrolled version of :py:class:`~lensless.ADMM`.

    Each of the ``n_iter`` unrolled iterations has its own copy of the ADMM
    hyperparameters ``mu1``, ``mu2``, ``mu3`` and ``tau``. They are registered
    as :py:class:`torch.nn.Parameter` (hence learnable) unless
    ``self.skip_unrolled`` is set, in which case they are plain tensors.
    """

    def __init__(
        self,
        psf,
        dtype=None,
        n_iter=5,
        mu1=1e-6,
        mu2=1e-5,
        mu3=4e-5,
        tau=0.0001,
        psi=None,
        psi_adj=None,
        psi_gram=None,
        pad=False,
        norm="backward",
        **kwargs,
    ):
        """
        Parameters
        ----------
        psf : :py:class:`~torch.Tensor`
            Point spread function (PSF) that models forward propagation.
            Must be of shape (depth, height, width, channels) even if
            depth = 1 and channels = 1. You can use :py:func:`~lensless.io.load_psf`
            to load a PSF from a file such that it is in the correct format.
        dtype : float32 or float64
            Data type to use for optimization. Default is float32.
        n_iter : int, optional
            Number of iterations to unroll, by default 5.
        mu1 : float
            Initial step size for the forward-model split (variables ``X`` / ``xi``).
        mu2 : float
            Initial step size for the sparsity split (variables ``U`` / ``eta``).
        mu3 : float
            Initial step size for the non-negativity split (variables ``W`` / ``rho``).
        tau : float
            Initial weight for L1 norm of `psi` applied to the image estimate
            (soft-thresholding uses the threshold ``tau / mu2``).
        psi : :py:class:`function`, optional
            Operator to map image to a space that the image is assumed to be
            sparse in (hence L1 norm). Default is to use total variation (TV)
            operator.
        psi_adj : :py:class:`function`
            Adjoint of `psi`. Required when `psi` is given.
        psi_gram : :py:class:`function`
            Function to compute gram of `psi`; it is called with the padded
            image shape. Required when `psi` is given.
        pad : bool, optional
            Whether to pad the image with zeros before applying the PSF.
        norm : str
            Normalization to use for the convolution. Options are "forward",
            "backward", and "ortho". Default is "backward".
        **kwargs
            Forwarded to the parent ``TrainableReconstructionAlgorithm``.
        """
        # reset=False: the hyperparameters below must exist before reset() runs.
        super().__init__(
            psf, n_iter=n_iter, dtype=dtype, pad=pad, norm=norm, reset=False, **kwargs
        )

        # One value per unrolled iteration. Assignment order (mu1, mu2, mu3, tau)
        # is kept so that parameter registration order is unchanged.
        if not self.skip_unrolled:
            self._mu1_p = torch.nn.Parameter(
                torch.ones(self._n_iter, device=self._psf.device) * mu1
            )
            self._mu2_p = torch.nn.Parameter(
                torch.ones(self._n_iter, device=self._psf.device) * mu2
            )
            self._mu3_p = torch.nn.Parameter(
                torch.ones(self._n_iter, device=self._psf.device) * mu3
            )
            self._tau_p = torch.nn.Parameter(
                torch.ones(self._n_iter, device=self._psf.device) * tau
            )
        else:
            self._mu1_p = torch.ones(self._n_iter, device=self._psf.device) * mu1
            self._mu2_p = torch.ones(self._n_iter, device=self._psf.device) * mu2
            self._mu3_p = torch.ones(self._n_iter, device=self._psf.device) * mu3
            self._tau_p = torch.ones(self._n_iter, device=self._psf.device) * tau

        # set prior
        if psi is None:
            # use the default _Psi / _PsiT methods (finite differences)
            self._PsiTPsi = finite_diff_gram(self._padded_shape, self._dtype, self.is_torch)
        else:
            assert psi_adj is not None, "`psi_adj` must be provided when `psi` is given."
            assert psi_gram is not None, "`psi_gram` must be provided when `psi` is given."
            assert callable(psi), "`psi` must be callable."
            assert callable(psi_adj), "`psi_adj` must be callable."
            assert callable(psi_gram), "`psi_gram` must be callable."
            # Instance attributes shadow the default _Psi / _PsiT methods.
            self._Psi = psi
            self._PsiT = psi_adj
            self._PsiTPsi = psi_gram(self._padded_shape)
        self._PsiTPsi = self._PsiTPsi.to(self._psf.device)

        self.reset()

    def _Psi(self, x):
        """
        Operator to map image to space that the image is assumed to be sparse
        in. Default implementation: ``finite_diff``.
        """
        return finite_diff(x)

    def _PsiT(self, U):
        """
        Adjoint of `_Psi`. Default implementation: ``finite_diff_adj``.
        """
        return finite_diff_adj(U)

    def reset(self, batch_size=1):
        """
        (Re)initialize the ADMM state and precompute the per-iteration divisors.

        Parameters
        ----------
        batch_size : int
            Not used by this implementation; the state is created with a
            leading batch dimension of 1.
        """
        if self._data is not None:
            device = self._data.device
        else:
            device = self._convolver._H.device

        # ensure that mu1, mu2, mu3, tau are positive
        self._mu1 = torch.abs(self._mu1_p).to(device)
        self._mu2 = torch.abs(self._mu2_p).to(device)
        self._mu3 = torch.abs(self._mu3_p).to(device)
        self._tau = torch.abs(self._tau_p).to(device)

        # TODO initialize without padding
        if self._initial_est is not None:
            self._image_est = self._initial_est.to(device)
        else:
            self._image_est = torch.zeros([1] + self._padded_shape, dtype=self._dtype).to(device)

        self._X = torch.zeros_like(self._image_est)
        # _Psi is evaluated here only to obtain the shape of the sparse-domain variable.
        self._U = torch.zeros_like(self._Psi(self._image_est))
        self._W = torch.zeros_like(self._X)

        # Check for *any* non-zero entry: testing `.max()` would misclassify an
        # estimate whose maximum is 0 but which has negative entries as all-zero.
        if torch.any(self._image_est != 0):
            self._forward_out = self._convolver.convolve(self._image_est)
            self._Psi_out = self._Psi(self._image_est)
        else:
            # zero estimate: skip the operators and use zero outputs
            # (assumes convolve / Psi map zero to zero, i.e. are linear)
            self._forward_out = torch.zeros_like(self._X)
            self._Psi_out = torch.zeros_like(self._U)

        # dual variables
        self._xi = torch.zeros_like(self._image_est)
        self._eta = torch.zeros_like(self._U)
        self._rho = torch.zeros_like(self._X)

        # precompute_R_divmat [iter, batch, depth, height, width, channels]
        self._R_divmat = 1.0 / (
            self._mu1[:, None, None, None, None, None]
            * (torch.abs(self._convolver._Hadj * self._convolver._H))[None, ...]
            + self._mu2[:, None, None, None, None, None] * torch.abs(self._PsiTPsi).to(device)
            + self._mu3[:, None, None, None, None, None]
        ).type(self._complex_dtype)

        # precompute_X_divmat [iter, batch, depth, height, width, channels]
        self._X_divmat = 1.0 / (
            self._convolver._pad(torch.ones_like(self._convolver._psf))[None, ...]
            + self._mu1[:, None, None, None, None, None]
        )

    def _U_update(self, iter):
        """Sparsity (TV by default) update: soft-threshold ``Psi(x) + eta / mu2``."""
        # reuse cached _Psi_out to avoid computing sparse operator twice
        self._U = soft_thresh(
            self._Psi_out + self._eta / self._mu2[iter], self._tau[iter] / self._mu2[iter]
        )

    def _X_update(self, iter):
        """Forward-model split update, using the precomputed ``_X_divmat``."""
        # reuse cached _forward_out to avoid computing forward model twice
        self._X = self._X_divmat[iter] * (
            self._xi + self._mu1[iter] * self._forward_out + self._convolver._pad(self._data)
        )

    def _image_update(self, iter):
        """Image-estimate update, solved in the Fourier domain with ``_R_divmat``."""
        rk = (
            (self._mu3[iter] * self._W - self._rho)
            + self._PsiT(self._mu2[iter] * self._U - self._eta)
            + self._convolver.deconvolve(self._mu1[iter] * self._X - self._xi)
        )
        freq_space_result = self._R_divmat[iter] * torch.fft.rfft2(rk, dim=(-3, -2))
        self._image_est = torch.fft.irfft2(
            freq_space_result, dim=(-3, -2), s=self._convolver._padded_shape[-3:-1]
        )

    def _W_update(self, iter):
        """Non-negativity update."""
        self._W = torch.maximum(
            self._rho / self._mu3[iter] + self._image_est, torch.zeros_like(self._image_est)
        )

    def _xi_update(self, iter):
        """Dual update for the forward-model split."""
        # reuse cached _forward_out to avoid computing forward model twice
        self._xi = self._xi + self._mu1[iter] * (self._forward_out - self._X)

    def _eta_update(self, iter):
        """Dual update for the sparsity split."""
        # reuse cached _Psi_out to avoid applying the sparse operator again
        self._eta = self._eta + self._mu2[iter] * (self._Psi_out - self._U)

    def _rho_update(self, iter):
        """Dual update for the non-negativity split."""
        self._rho = self._rho + self._mu3[iter] * (self._image_est - self._W)

    def _update(self, iter):
        """
        Run unrolled iteration ``iter``: primal updates (U, X, W, image), then
        refresh the cached forward / sparse operator outputs, then dual updates.
        """
        self._U_update(iter)
        self._X_update(iter)
        self._W_update(iter)
        self._image_update(iter)

        # update forward and sparse operators
        self._forward_out = self._convolver.convolve(self._image_est)
        self._Psi_out = self._Psi(self._image_est)

        self._xi_update(iter)
        self._eta_update(iter)
        self._rho_update(iter)

    def _form_image(self):
        """Crop the padded estimate and clip negative values to zero."""
        image = self._convolver._crop(self._image_est)
        image = torch.clip(image, min=0.0)
        return image