Source code for lensless.recon.trainable_recon

# #############################################################################
# trainable_recon.py
# ==================
# Authors :
# Yohann PERRON [yohann.perron@gmail.com]
# Eric BEZZAM [ebezzam@gmail.com]
# #############################################################################

import pathlib as plib
from matplotlib import pyplot as plt
from lensless.recon.recon import ReconstructionAlgorithm
from lensless.utils.plot import plot_image
from lensless.recon.rfft_convolve import RealFFTConvolve2D

try:
    import torch

except ImportError:
    raise ImportError("Pytorch is require to use trainable reconstruction algorithms.")


[docs]class TrainableReconstructionAlgorithm(ReconstructionAlgorithm, torch.nn.Module): """ Abstract class for defining lensless imaging reconstruction algorithms with trainable parameters. The following abstract methods need to be defined: * ``_update``: updating state variables at each iterations. * ``reset``: reset state variables. * ``_form_image``: any pre-processing that needs to be done in order to view the image estimate, e.g. reshaping or clipping. One advantage of deriving from this abstract class is that functionality for iterating, saving, and visualization is already implemented, namely in the ``apply`` method. Consequently, using a reconstruction algorithm that derives from it boils down to four steps: 1. Creating an instance of the reconstruction algorithm. 2. Training the algorithm 3. Setting the data. 4. Applying the algorithm. """
[docs] def __init__( self, psf, dtype=None, n_iter=1, pre_process=None, post_process=None, skip_unrolled=False, skip_pre=False, skip_post=False, return_intermediate=False, legacy_denoiser=False, compensation=None, compensation_residual=True, psf_network=None, psf_residual=True, # background subtraction direct_background_subtraction=False, background_network=None, integrated_background_subtraction=False, **kwargs, ): """ Base constructor. Derived constructor may define new state variables here and also reset them in `reset`. Parameters ---------- psf : :py:class:`~torch.Tensor` Point spread function (PSF) that models forward propagation. Must be of shape (depth, height, width, channels) even if depth = 1 and channels = 1. You can use :py:func:`~lensless.io.load_psf` to load a PSF from a file such that it is in the correct format. dtype : float32 or float64 Data type to use for optimization. n_iter : int Number of iterations for unrolled algorithm. pre_process : :py:class:`function` or :py:class:`~torch.nn.Module`, optional If :py:class:`function` : Function to apply to the image estimate before algorithm. Its input most be (image to process, noise_level), where noise_level is a learnable parameter. If it include aditional learnable parameters, they will not be added to the parameter list of the algorithm. To allow for traning, the function must be autograd compatible. If :py:class:`~torch.nn.Module` : A DruNet compatible network to apply to the image estimate before algorithm. See ``utils.image.apply_denoiser`` for more details. The network will be included as a submodule of the algorithm and its parameters will be added to the parameter list of the algorithm. If this isn't intended behavior, set requires_grad=False. post_process : :py:class:`function` or :py:class:`~torch.nn.Module`, optional If :py:class:`function` : Function to apply to the image estimate after the whole algorithm. Its input most be (image to process, noise_level), where noise_level is a learnable parameter. If it include aditional learnable parameters, they will not be added to the parameter list of the algorithm. To allow for traning, the function must be autograd compatible. If :py:class:`~torch.nn.Module` : A DruNet compatible network to apply to the image estimate after the whole algorithm. See ``utils.image.apply_denoiser`` for more details. The network will be included as a submodule of the algorithm and its parameters will be added to the parameter list of the algorithm. If this isn't intended behavior, set requires_grad=False. skip_unrolled : bool, optional Whether to skip the unrolled algorithm and only apply the pre- or post-processor block (e.g. to just use a U-Net for reconstruction). return_unrolled_output : bool, optional Whether to return the output of the unrolled algorithm if also using a post-processor block. compensation : list, optional Number of channels for each intermediate output in compensation layer, as in "Robust Reconstruction With Deep Learning to Handle Model Mismatch in Lensless Imaging" (2021). Post-processor must be defined if compensation provided. compensation_residual : bool, optional Whether to use residual connection in compensation layer. psf_network : :py:class:`function` or :py:class:`~torch.nn.Module`, optional Function or model to apply to PSF prior to camera inversion. psf_residual : bool, optional Whether to use residual connection in PSF network. """ assert isinstance(psf, torch.Tensor), "PSF must be a torch.Tensor" super(TrainableReconstructionAlgorithm, self).__init__( psf, dtype=dtype, n_iter=n_iter, **kwargs ) self._legacy_denoiser = legacy_denoiser self.input_background = False if pre_process is not None: self.set_pre_process(pre_process) else: self.pre_process = None self.pre_process_model = None self.pre_process_param = None if post_process is not None: self.set_post_process(post_process) else: self.post_process = None self.post_process_model = None self.post_process_param = None self.skip_unrolled = skip_unrolled self.skip_pre = skip_pre self.skip_post = skip_post # background subtraction self.direct_background_subtraction = direct_background_subtraction self.integrated_background_subtraction = integrated_background_subtraction self.learned_background_subtraction = False self.background_network = None if background_network is not None: if integrated_background_subtraction: assert ( direct_background_subtraction is False ), "Cannot use direct_background_subtraction and background_network at the same time." assert ( pre_process is None ), "Cannot use pre_process with integrated_background_subtraction." self.pre_process = background_network else: self.learned_background_subtraction = True assert ( direct_background_subtraction is False ), "Cannot use direct_background_subtraction and background_network at the same time." self.set_background_network(background_network) # PSF network self.psf_network = None self.psf_residual = psf_residual if psf_network is not None: self.set_psf_network(psf_network) # compensation branch self.return_intermediate = return_intermediate self.compensation_branch = compensation if compensation is not None: from lensless.recon.utils import CompensationBranch assert ( post_process is not None ), "If compensation_branch is True, post_process must be defined." assert ( len(compensation) == n_iter ), "compensation_nc must have the same length as n_iter" self.compensation_branch = CompensationBranch( compensation, residual=compensation_residual ) self.compensation_branch = self.compensation_branch.to(self._psf.device) if self.return_intermediate: assert ( post_process is not None or pre_process is not None ), "If return_intermediate is True, post_process or pre_process must be defined." if self.skip_unrolled: assert ( post_process is not None or pre_process is not None ), "If skip_unrolled is True, pre_process or post_process must be defined."
def _prepare_process_block(self, process): """ Method for preparing the pre or post process block. Parameters ---------- process : :py:class:`function` or :py:class:`~torch.nn.Module`, optional Pre or post process block to prepare. """ if isinstance(process, torch.nn.Module): # If the post_process is a torch module, we assume it is a DruNet like network. from lensless.recon.utils import get_drunet_function, get_drunet_function_v2 from lensless.recon.restormer import get_restormer_function process_model = process # handle DataParallel if isinstance(process, torch.nn.DataParallel): process_type = process.module.__class__.__name__ else: process_type = process.__class__.__name__ if process_type == "Restormer": process_function = get_restormer_function(process_model) elif self._legacy_denoiser: process_function = get_drunet_function(process_model, mode="train") else: process_function = get_drunet_function_v2(process_model, mode="train") elif process is not None: # Otherwise, we assume it is a function. assert callable(process), "pre_process must be a callable function" process_function = process process_model = None else: process_function = None process_model = None if process_function is not None: process_param = torch.nn.Parameter(torch.tensor([1.0], device=self._psf.device)) else: process_param = None return process_function, process_model, process_param def set_pre_process(self, pre_process): if isinstance(pre_process, torch.nn.DataParallel): if hasattr(pre_process.module, "input_background"): self.input_background = pre_process.module.input_background else: if hasattr(pre_process, "input_background"): self.input_background = pre_process.input_background ( self.pre_process, self.pre_process_model, self.pre_process_param, ) = self._prepare_process_block(pre_process) def set_post_process(self, post_process): ( self.post_process, self.post_process_model, self.post_process_param, ) = self._prepare_process_block(post_process) def set_background_network(self, background_network): ( self.background_network, self.background_network_model, self.background_network_param, ) = self._prepare_process_block(background_network) def set_psf_network(self, psf_network): ( self.psf_network, self.psf_network_model, self.psf_network_param, ) = self._prepare_process_block(psf_network) def freeze_pre_process(self): """ Method for freezing the pre process block. """ if self.pre_process_param is not None: self.pre_process_param.requires_grad = False if self.pre_process_model is not None: for param in self.pre_process_model.parameters(): param.requires_grad = False def freeze_post_process(self): """ Method for freezing the post process block. """ if self.post_process_param is not None: self.post_process_param.requires_grad = False if self.post_process_model is not None: for param in self.post_process_model.parameters(): param.requires_grad = False def unfreeze_pre_process(self): """ Method for unfreezing the pre process block. """ if self.pre_process_param is not None: self.pre_process_param.requires_grad = True if self.pre_process_model is not None: for param in self.pre_process_model.parameters(): param.requires_grad = True def unfreeze_post_process(self): """ Method for unfreezing the post process block. """ if self.post_process_param is not None: self.post_process_param.requires_grad = True if self.post_process_model is not None: for param in self.post_process_model.parameters(): param.requires_grad = True
[docs] def forward(self, batch, psfs=None, background=None): """ Method for performing iterative reconstruction on a batch of images. This implementation is a properly vectorized implementation of FISTA. Parameters ---------- batch : :py:class:`~torch.Tensor` of shape (batch, depth, channels, height, width) The lensless images to reconstruct. psfs : :py:class:`~torch.Tensor` of shape (batch, depth, channels, height, width) The lensless images to reconstruct. Returns ------- :py:class:`~torch.Tensor` of shape (batch, depth, channels, height, width) The reconstructed images. """ self._data = batch assert len(self._data.shape) == 5, "batch must be of shape (N, D, C, H, W)" batch_size = batch.shape[0] if self.direct_background_subtraction: assert ( background is not None ), "If direct_background_subtraction is True, background must be defined." self._data = self._data - background self._data = torch.clamp(self._data, 0, 1) elif self.learned_background_subtraction: assert ( background is not None ), "If learned_background_subtraction is True, background must be defined." assert ( self.background_network is not None ), "If learned_background_subtraction is True, background_network must be defined." self._data = self._data - self.background_network( background, self.background_network_param ).to(self._data.device) self._data = torch.clamp(self._data, 0, 1) # set / transform PSFs if need be if self.psf_network is not None: if psfs is None: psfs = self._psf.to(self._data.device) if self.psf_residual: psfs = self.psf_network(psfs, self.psf_network_param).to(psfs.device) + psfs else: psfs = self.psf_network(psfs, self.psf_network_param).to(psfs.device) if psfs is not None: # assert same shape assert psfs.shape[-3:] == batch.shape[-3:], "psfs must have the same shape as batch" # -- update convolver self._convolver = RealFFTConvolve2D(psfs.to(self._data.device), **self._convolver_param) elif self._data.device != self._convolver._H.device: # needed for multi-GPU... TODO better solution? self._convolver = RealFFTConvolve2D( self._psf.to(self._data.device), **self._convolver_param ) # pre process data if self.integrated_background_subtraction: # use preprocess for background subtraction self._data = self.pre_process(self._data, noise_level=background) elif self.pre_process is not None and not self.skip_pre: # preproc that doesn't do background subtraction device_before = self._data.device self._data = self.pre_process( self._data, noise_level=self.pre_process_param, background=background if self.input_background else None, ) self._data = self._data.to(device_before) pre_processed = self._data self.reset(batch_size=batch_size) # unrolled algorithm if not self.skip_unrolled: if self.compensation_branch is not None: compensation_branch_inputs = [self._data] for i in range(self._n_iter): self._update(i) if self.compensation_branch is not None and i < self._n_iter - 1: compensation_branch_inputs.append(self._form_image()) image_est = self._form_image() else: image_est = self._data # post process data if self.post_process is not None and not self.skip_post: compensation_output = None if self.compensation_branch is not None: compensation_output = self.compensation_branch(compensation_branch_inputs) final_est = self.post_process( image_est, noise_level=self.post_process_param, compensation_output=compensation_output, ) else: final_est = image_est if self.return_intermediate: return final_est, image_est, pre_processed, psfs else: return final_est
[docs] def apply( self, disp_iter=10, plot_pause=0.2, plot=True, save=False, gamma=None, ax=None, reset=True, output_intermediate=False, background=None, ): """ Method for performing iterative reconstruction. Contrary to non-trainable reconstruction algorithm, the number of iteration isn't required. Note that `set_data` must be called beforehand. Parameters ---------- disp_iter : int How often to display and/or intermediate reconstruction (in number of iterations). If `None` OR `plot` or `save` are False, no intermediate reconstruction will be plotted/saved. plot_pause : float Number of seconds to pause after displaying reconstruction. plot : bool Whether to plot final result, and intermediate results if `disp_iter` is not None. save : bool Whether to save final result (as PNG), and intermediate results if `disp_iter` is not None. gamma : float, optional Gamma correction factor to apply for plots. Default is None. ax : :py:class:`~matplotlib.axes.Axes`, optional `Axes` object to fill for plotting/saving, default is to create one. output_intermediate : bool, optional Whether to output intermediate reconstructions after preprocessing and before postprocessing. Returns ------- final_im : :py:class:`~torch.Tensor` Final reconstruction. ax : :py:class:`~matplotlib.axes.Axes` `Axes` object on which final reconstruction is displayed. Only returning if `plot` or `save` is True. """ if self.direct_background_subtraction: assert ( background is not None ), "If direct_background_subtraction is True, background must be defined." self._data = self._data - background self._data = torch.clamp(self._data, 0, 1) elif self.learned_background_subtraction: assert ( background is not None ), "If learned_background_subtraction is True, background must be defined." assert ( self.background_network is not None ), "If learned_background_subtraction is True, background_network must be defined." self._data = self._data - self.background_network( background, self.background_network_param ).to(self._data.device) self._data = torch.clamp(self._data, 0, 1) # transform PSF if need be psf = None if self.psf_network is not None: psf = self._psf.to(self._data.device) if self.psf_residual: psf = self.psf_network(psf, self.psf_network_param).to(psf.device) + psf else: psf = self.psf_network(psf, self.psf_network_param).to(psf.device) if psf is not None: # -- update convolver self._convolver = RealFFTConvolve2D(psf.to(self._data.device), **self._convolver_param) pre_processed_image = None if self.integrated_background_subtraction: # use preprocess for background subtraction self._data = self.pre_process(self._data, noise_level=background) elif self.pre_process is not None and not self.skip_pre: # preproc that doesn't do background subtraction self._data = self.pre_process( self._data, noise_level=self.pre_process_param, background=background if self.input_background else None, ) if self.pre_process is not None and output_intermediate: pre_processed_image = self._data[0, ...].clone() if not self.skip_unrolled: im = super(TrainableReconstructionAlgorithm, self).apply( n_iter=self._n_iter, disp_iter=disp_iter, plot_pause=plot_pause, plot=plot, save=save, gamma=gamma, ax=ax, reset=reset, ) else: im = self._data # remove plot if returned if plot: im, _ = im # post process data pre_post_process_image = None if self.post_process is not None and not self.skip_post: compensation_output = None if self.compensation_branch is not None: compensation_output = self.compensation_branch(self.compensation_branch_inputs) # apply post process if output_intermediate: pre_post_process_image = im.clone() im = self.post_process( im, noise_level=self.post_process_param, compensation_output=compensation_output ) if plot: ax = plot_image(self._get_numpy_data(im[0]), ax=ax, gamma=gamma) ax.set_title( "Final reconstruction after {} iterations and post process".format(self._n_iter) ) if save: plt.savefig(plib.Path(save) / "final.png") if output_intermediate: return_items = im, pre_post_process_image, pre_processed_image if self.psf_network is not None: return_items += (psf,) return return_items elif plot: return im, ax else: return im