Source code for lensless.recon.tikhonov

# #############################################################################
# tikhonov.py
# =================
# Authors :
# Eric BEZZAM [ebezzam@gmail.com]
# Aaron FARGEON [aa.fargeon@gmail.com]
# #############################################################################

"""
Tikhonov
========

The :py:class:`~lensless.recon.tikhonov.CodedApertureReconstruction` class is meant
to recover an image from a :py:class:`~lensless.hardware.mask.CodedAperture` lensless
capture, using the analytical solution to the Tikhonov optimization problem
(least squares problem with L2 regularization term), as in the `FlatCam paper
<https://arxiv.org/abs/1509.00116>`_ (Eq. 7).
"""

import numpy as np
from numpy.linalg import multi_dot

try:
    import torch

    torch_available = True
except ImportError:
    torch_available = False


class CodedApertureReconstruction:
    """
    Reconstruction method for the (non-iterative) Tikhonov algorithm presented in the
    `FlatCam paper <https://arxiv.org/abs/1509.00116>`_.

    The measurement is modelled as ``Y = P @ X @ Q.T`` and the estimate is the
    closed-form minimiser of ``||P @ X @ Q.T - Y||^2 + lmbd * ||X||^2``, computed
    from the SVDs of ``P`` and ``Q``.

    TODO: operations in float32
    """

    def __init__(self, mask, image_shape, P=None, Q=None, lmbd=3e-4):
        """
        Parameters
        ----------
        mask : :py:class:`lensless.hardware.mask.CodedAperture`
            Coded aperture mask object.
        image_shape : (`array-like` or `tuple`)
            The shape of the image to reconstruct.
        P : :py:class:`~numpy.ndarray`, optional
            Left convolution matrix in measurement operator. Must be of shape
            (measurement_resolution[0], image_shape[0]). By default, it is generated
            from the mask. In practice, it may be useful to measure as in the
            FlatCam paper.
        Q : :py:class:`~numpy.ndarray`, optional
            Right convolution matrix in measurement operator. Must be of shape
            (measurement_resolution[1], image_shape[1]). By default, it is generated
            from the mask. In practice, it may be useful to measure as in the
            FlatCam paper.
        lmbd : float, optional
            Regularization parameter. Default value is `3e-4` as in the FlatCam paper
            `code <https://github.com/tanjasper/flatcam/blob/master/python/demo.py>`_.

        Notes
        -----
        If either ``P`` or ``Q`` is ``None``, *both* matrices are generated from the
        mask and any single matrix that was passed is ignored. The shape checks are
        only performed when both matrices are supplied.
        """

        self.lmbd = lmbd

        if P is None or Q is None:
            self.P, self.Q = mask.get_conv_matrices(image_shape)
        else:
            self.P = P
            self.Q = Q
            assert self.P.shape == (
                mask.resolution[0],
                image_shape[0],
            ), "Left matrix P shape mismatch"
            assert self.Q.shape == (
                mask.resolution[1],
                image_shape[1],
            ), "Right matrix Q shape mismatch"

    def apply(self, img):
        """
        Method for performing Tikhonov reconstruction.

        The stored matrices ``self.P`` and ``self.Q`` are never modified, so the
        method can be called repeatedly and with a mix of NumPy and PyTorch inputs.

        Parameters
        ----------
        img : :py:class:`~numpy.ndarray` or :py:class:`torch.Tensor`
            Lensless capture measurement. Must be 3D even if grayscale.

        Returns
        -------
        :py:class:`~numpy.ndarray` or :py:class:`~torch.Tensor`
            Reconstructed image, in the same format as the measurement. Negative
            values are clipped to 0 and the result is min-max normalized over all
            channels jointly.
        """
        assert (
            len(img.shape) == 3
        ), "Object should be a 3D array or tensor (HxWxC) even if grayscale."

        if torch_available and isinstance(img, torch.Tensor):
            x_est = self._solve_torch(img)
            # Non-negativity constraint: setting all negative values to 0
            x_est = torch.clamp(x_est, min=0)
        else:
            x_est = self._solve_numpy(img)
            # Non-negativity constraint: setting all negative values to 0
            x_est = x_est.clip(min=0)

        # Normalizing the image.
        # NOTE: if the clipped estimate is constant, max == min and this is 0 / 0.
        return (x_est - x_est.min()) / (x_est.max() - x_est.min())

    def _solve_numpy(self, img):
        """Return the unclipped, unnormalized Tikhonov estimate for a NumPy measurement."""
        # The SVDs and the denominator depend only on P, Q and lmbd, so they are
        # computed once and shared by every channel. The reduced SVD is enough:
        # the extra columns of a full SVD only ever get multiplied by zeros.
        UL, SL, VLh = np.linalg.svd(self.P, full_matrices=False)
        UR, SR, VRh = np.linalg.svd(self.Q, full_matrices=False)
        denom = np.outer(np.square(SL), np.square(SR)) + self.lmbd

        n_channels = img.shape[-1]
        x_est = np.empty([self.P.shape[1], self.Q.shape[1], n_channels])

        for c in range(n_channels):
            # Measurement expressed in the left/right singular bases
            projected = multi_dot([UL.T, img[:, :, c], UR])
            # Element-wise Tikhonov filter: s_l * s_r / (s_l^2 * s_r^2 + lmbd)
            inner = SL[:, None] * projected * SR[None, :] / denom
            x_est[:, :, c] = multi_dot([VLh.T, inner, VRh])

        return x_est

    def _solve_torch(self, img):
        """Return the unclipped, unnormalized Tikhonov estimate for a PyTorch measurement."""
        # Work in the measurement's floating dtype (float32 for non-float input)
        # and on its device. Local tensors are used so that self.P / self.Q keep
        # their original type between calls.
        dtype = img.dtype if img.is_floating_point() else torch.float32
        P = torch.as_tensor(self.P, dtype=dtype, device=img.device)
        Q = torch.as_tensor(self.Q, dtype=dtype, device=img.device)
        meas = img.to(dtype)

        # Channel-independent factors, computed once (see _solve_numpy)
        UL, SL, VLh = torch.linalg.svd(P, full_matrices=False)
        UR, SR, VRh = torch.linalg.svd(Q, full_matrices=False)
        denom = torch.outer(SL**2, SR**2) + self.lmbd

        n_channels = img.shape[-1]
        x_est = torch.empty(
            [P.shape[1], Q.shape[1], n_channels], dtype=dtype, device=img.device
        )

        for c in range(n_channels):
            projected = torch.linalg.multi_dot([UL.T, meas[:, :, c], UR])
            inner = SL[:, None] * projected * SR[None, :] / denom
            x_est[:, :, c] = torch.linalg.multi_dot([VLh.T, inner, VRh])

        return x_est