Reconstruction
Check out this notebook on Google Colab for an overview of the reconstruction algorithms available in LenslessPiCam (analytic and learned).
The core algorithmic component of LenslessPiCam is the abstract
class ReconstructionAlgorithm. The five reconstruction
strategies available in LenslessPiCam derive from this class:
GradientDescent: projected gradient descent with a non-negativity constraint. Two accelerated approaches are also available: NesterovGradientDescent and FISTA.
ADMM: alternating direction method of multipliers (ADMM) with a non-negativity constraint and a total variation (TV) regularizer [1].
APGD: accelerated proximal gradient descent with Pycsou as a backend. Any differentiable or proximal operator can be used as long as it is compatible with Pycsou, namely derives from one of DiffFunc or ProxFunc.
UnrolledFISTA: unrolled FISTA with a non-negativity constraint.
UnrolledADMM: unrolled ADMM with a non-negativity constraint and a total variation (TV) regularizer [1].
Note that the unrolled algorithms derive from the abstract class
TrainableReconstructionAlgorithm, which itself derives from
ReconstructionAlgorithm while adding functionality
for training on batches and adding trainable pre- and post-processing
blocks.
New reconstruction algorithms can be conveniently implemented by deriving from the abstract class and defining the following abstract methods:
the update step: _update.
a method to reset state variables: reset.
an image formation method: _form_image.
One advantage of deriving from ReconstructionAlgorithm is
that functionality for iterating, saving, and visualization is already
implemented. Consequently, using a reconstruction algorithm that derives
from it boils down to three steps:
Creating an instance of the reconstruction algorithm.
Setting the data.
Applying the algorithm.
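The pattern can be sketched without LenslessPiCam itself. The toy base class below is a hypothetical stand-in for ReconstructionAlgorithm: its attribute names and the damped-average update are assumptions made for illustration, not the library's implementation. It only shows how _update, reset, and _form_image plug into a shared apply loop, and how the three usage steps look from the caller's side.

```python
from abc import ABC, abstractmethod


class ToyReconstruction(ABC):
    """Hypothetical stand-in for ReconstructionAlgorithm (not the real class)."""

    def __init__(self, n_iter=100):
        self._n_iter = n_iter
        self._data = None
        self._estimate = None

    def set_data(self, data):
        self._data = list(data)
        self.reset()

    def apply(self, n_iter=None):
        # Shared iteration logic: subclasses only supply the three hooks.
        if self._data is None:
            raise ValueError("call set_data() before apply()")
        for _ in range(self._n_iter if n_iter is None else n_iter):
            self._update()
        return self._form_image()

    @abstractmethod
    def _update(self): ...

    @abstractmethod
    def reset(self): ...

    @abstractmethod
    def _form_image(self): ...


class DampedAverage(ToyReconstruction):
    """Toy algorithm: move the estimate halfway towards the data each step."""

    def reset(self):
        self._estimate = [0.0] * len(self._data)

    def _update(self):
        self._estimate = [e + 0.5 * (d - e) for e, d in zip(self._estimate, self._data)]

    def _form_image(self):
        # Clip to non-negative values before viewing.
        return [max(e, 0.0) for e in self._estimate]


recon = DampedAverage(n_iter=20)  # 1. create an instance
recon.set_data([1.0, 2.0, -3.0])  # 2. set the data
res = recon.apply()               # 3. apply the algorithm
```

Because iteration lives in the base class, a new algorithm only has to describe one step, its initial state, and how to turn that state into a viewable result.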
2D example (ADMM)
For example, for ADMM:
recon = ADMM(psf)
recon.set_data(data)
res = recon.apply(n_iter=n_iter)
A full running example can be found in scripts/recon/admm.py and run as:
python scripts/recon/admm.py
Note that a YAML configuration script is defined in configs/admm_thumbs_up.yaml,
which is used by default. Individual parameters can be configured as such:
python scripts/recon/admm.py admm.n_iter=10 preprocess.gray=True
--help can be used to view all available parameters.
>> python scripts/recon/admm.py --help
...
== Config ==
Override anything in the config (foo.bar=value)
files:
  psf: data/psf/tape_rgb.png
  data: data/raw_data/thumbs_up_rgb.png
preprocess:
  downsample: 4
  shape: null
  flip: false
  bayer: false
  blue_gain: null
  red_gain: null
  single_psf: false
  gray: false
display:
  disp: 1
  plot: true
  gamma: null
save: false
admm:
  n_iter: 5
  mu1: 1.0e-06
  mu2: 1.0e-05
  mu3: 4.0e-05
  tau: 0.0001
...
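To make the dotted override syntax concrete, here is a minimal sketch of how an override such as admm.n_iter=10 maps onto a nested configuration. The helpers apply_override and parse_value are hypothetical and written only for this illustration; the configuration framework used by the scripts has its own, much richer parser, and this is not its implementation.

```python
import copy


def apply_override(config, override):
    """Return a copy of `config` with one `a.b=value` override applied."""
    dotted_key, raw_value = override.split("=", 1)
    *parents, leaf = dotted_key.split(".")
    result = copy.deepcopy(config)
    node = result
    for name in parents:
        node = node[name]  # raises KeyError for unknown groups
    if leaf not in node:
        raise KeyError(dotted_key)
    node[leaf] = parse_value(raw_value)
    return result


def parse_value(text):
    """Very small value parser: booleans, null, ints, floats, else string."""
    lowered = text.lower()
    if lowered in ("true", "false"):
        return lowered == "true"
    if lowered in ("null", "none"):
        return None
    for cast in (int, float):
        try:
            return cast(text)
        except ValueError:
            pass
    return text


# A subset of the default values shown in the help output.
defaults = {
    "preprocess": {"downsample": 4, "gray": False},
    "admm": {"n_iter": 5, "tau": 0.0001},
}

config = defaults
for item in ["admm.n_iter=10", "preprocess.gray=True"]:
    config = apply_override(config, item)
```

Each override touches exactly one leaf and leaves the defaults untouched, which is why several overrides can be combined freely on the command line.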
Alternatively, a new configuration file can be defined in the configs folder and
passed to the script:
python scripts/recon/admm.py -cn <CONFIG_FILENAME_WITHOUT_YAML_EXT>
3D example
It is also possible to reconstruct 3D scenes using GradientDescent or APGD. ADMM does not support 3D reconstruction yet.
This requires a 3D PSF as input, in the form of an .npy or .npz file: a set of 2D PSFs corresponding to the same diffuser sampled with light sources at different depths.
The input data for 3D reconstruction is still a 2D image, as collected by the camera. The reconstruction separates which part of the lensless data corresponds to which 2D PSF,
and therefore to which depth, effectively generating a 3D reconstruction, which is output as an .npy file. A 2D projection along the depth axis is also displayed to the user.
The same scripts for 2D reconstruction can be used for 3D reconstruction, namely scripts/recon/gradient_descent.py and scripts/recon/apgd_pycsou.py.
3D data is provided in LenslessPiCam, but it is simulated. Real example data can be obtained from Waller Lab.
For both the simulated data and the data from Waller Lab, it is best to set downsample=1:
python scripts/recon/gradient_descent.py \
input.psf="path/to/3D/psf.npy" \
input.data="path/to/lensless/data.tiff" \
preprocess.downsample=1
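As a rough illustration of what a 3D PSF file could contain, the sketch below stacks per-depth 2D PSFs into one array with the (depth, height, width, channels) layout required by the reconstruction classes and saves it as .npy. The random PSFs, the file name, and the maximum projection are assumptions for illustration; the exact on-disk layout expected by the scripts and the projection they display may differ.

```python
import os
import tempfile

import numpy as np

depth, height, width, channels = 3, 32, 48, 3
rng = np.random.default_rng(0)

# One 2D PSF per depth plane (random placeholders standing in for measured PSFs).
psfs_2d = [rng.random((height, width, channels)).astype(np.float32) for _ in range(depth)]

# Stack along a new leading axis: (depth, height, width, channels).
psf_3d = np.stack(psfs_2d, axis=0)

# Hypothetical file name inside a temporary directory.
path = os.path.join(tempfile.mkdtemp(), "psf_3d.npy")
np.save(path, psf_3d)
loaded = np.load(path)

# A reconstructed volume shares this layout; one simple 2D view is a maximum
# projection along depth (assumption: the library may project differently).
projection = loaded.max(axis=0)
```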
Other approaches
Scripts for other reconstruction algorithms can be found in scripts/recon and their
corresponding configurations in configs.
References
Reconstruction API
Abstract Class
- class lensless.ReconstructionAlgorithm(psf, dtype=None, pad=True, n_iter=100, initial_est=None, reset=True, denoiser=None, **kwargs)
Abstract class for defining lensless imaging reconstruction algorithms.
The following abstract methods need to be defined:
_update: updating state variables at each iteration.
reset: reset state variables.
_form_image: any pre-processing that needs to be done in order to view the image estimate, e.g. reshaping or clipping.
One advantage of deriving from this abstract class is that functionality for iterating, saving, and visualization is already implemented, namely in the apply method.
Consequently, using a reconstruction algorithm that derives from it boils down to three steps:
Creating an instance of the reconstruction algorithm.
Setting the data.
Applying the algorithm.
- __init__(psf, dtype=None, pad=True, n_iter=100, initial_est=None, reset=True, denoiser=None, **kwargs)
Base constructor. Derived constructor may define new state variables here and also reset them in reset.
- Parameters
psf (ndarray or Tensor) – Point spread function (PSF) that models forward propagation. Must be of shape (depth, height, width, channels) even if depth = 1 and channels = 1. You can use load_psf() to load a PSF from a file such that it is in the correct format.
dtype (float32 or float64) – Data type to use for optimization.
pad (bool, optional) – Whether data needs to be padded prior to convolution. User may wish to optimize padded data and set this to False, as is done for ADMM. Defaults to True.
n_iter (int, optional) – Number of iterations to run algorithm for. Can be overridden in apply.
initial_est (ndarray or Tensor, optional) – Initial estimate of the image. If not provided, the initial estimate is set to zero or to the mean of the data, depending on the algorithm.
reset (bool, optional) – Whether to reset state variables in the base constructor. Defaults to True. If False, you should call reset() at some point to initialize state variables.
denoiser (dict, optional) – Dictionary defining a denoiser for plug-and-play. Must contain the following keys:
"network": model to use as a denoiser.
"noise_level": noise level of the denoiser.
If provided, the denoiser will be used as a projection function at each iteration. Defaults to None.
- abstract _form_image()
Any pre-processing to form a viewable image, e.g. reshaping or clipping.
- apply(n_iter=None, disp_iter=-1, plot_pause=0.2, plot=False, save=False, gamma=None, ax=None, reset=True, background=None, **kwargs)
Method for performing iterative reconstruction. Note that set_data must be called beforehand.
- Parameters
n_iter (int, optional) – Number of iterations. If not provided, defaults to self._n_iter.
disp_iter (int) – How often to display and/or save intermediate reconstructions (in number of iterations). If None, or if plot and save are both False, no intermediate reconstruction will be plotted/saved.
plot_pause (float) – Number of seconds to pause after displaying reconstruction.
plot (bool) – Whether to plot final result, and intermediate results if disp_iter is not None.
save (bool) – Whether to save final result (as PNG), and intermediate results if disp_iter is not None.
gamma (float, optional) – Gamma correction factor to apply for plots. Default is None.
ax (Axes, optional) – Axes object to fill for plotting/saving, default is to create one.
reset (bool, optional) – Whether to reset state variables before applying reconstruction. Defaults to True. Set to False if continuing reconstruction from a previous state.
- Returns
final_im (ndarray) – Final reconstruction.
ax (Axes) – Axes object on which final reconstruction is displayed. Only returned if plot or save is True.
- get_image_estimate()
Get current image estimate as [Batch, Depth, Height, Width, Channels].
- reconstruction_error(prediction=None, lensless=None, psfs=None, normalize=True)
Compute reconstruction error.
Gradient Descent
- class lensless.GradientDescent(psf, dtype=None, proj=<function non_neg>, lip_fact=1.8, **kwargs)
Object for applying projected gradient descent.
- __init__(psf, dtype=None, proj=<function non_neg>, lip_fact=1.8, **kwargs)
- Parameters
psf (ndarray or Tensor) – Point spread function (PSF) that models forward propagation. Must be of shape (depth, height, width, channels) even if depth = 1 and channels = 1. You can use load_psf() to load a PSF from a file such that it is in the correct format.
dtype (float32 or float64) – Data type to use for optimization. Default is float32.
proj (function) – Projection function to apply at each iteration. Default is non-negative.
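For intuition, projected gradient descent for a convolutional forward model can be sketched in a few lines of NumPy. This is a simplified illustration, not the library's code: it assumes a single-channel 2D image, circular convolution via the FFT, and a step size of 1 / L, where L is the Lipschitz constant of the gradient. The function name projected_gradient_descent and the test scene are hypothetical.

```python
import numpy as np


def non_neg(x):
    """Projection onto the non-negative orthant."""
    return np.maximum(x, 0)


def projected_gradient_descent(psf, data, n_iter=100, proj=non_neg):
    """Minimize 0.5 * ||psf (*) x - data||^2 subject to proj (circular convolution)."""
    H = np.fft.rfft2(psf)
    # The gradient's Lipschitz constant is max |H|^2; use a step of 1 / L.
    step = 1.0 / np.max(np.abs(H) ** 2)
    x = np.zeros_like(data)
    for _ in range(n_iter):
        residual = np.fft.irfft2(H * np.fft.rfft2(x), s=data.shape) - data
        grad = np.fft.irfft2(np.conj(H) * np.fft.rfft2(residual), s=data.shape)
        x = proj(x - step * grad)
    return x


rng = np.random.default_rng(0)
scene = np.zeros((32, 32))
scene[8:12, 10:20] = 1.0
psf = rng.random((32, 32))
psf /= psf.sum()
# Simulated noise-free measurement.
data = np.fft.irfft2(np.fft.rfft2(psf) * np.fft.rfft2(scene), s=scene.shape)

estimate = projected_gradient_descent(psf, data, n_iter=200)
```

The projection is applied after every gradient step, so the estimate stays feasible throughout; swapping proj for another function changes the constraint without touching the rest of the loop.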
- class lensless.NesterovGradientDescent(psf, dtype=None, proj=<function non_neg>, p=0, mu=0.9, **kwargs)
Object for applying projected gradient descent with Nesterov momentum for acceleration.
A nice tutorial/blog post on Nesterov momentum can be found here.
- __init__(psf, dtype=None, proj=<function non_neg>, p=0, mu=0.9, **kwargs)
- Parameters
psf (ndarray or Tensor) – Point spread function (PSF) that models forward propagation. Must be of shape (depth, height, width, channels) even if depth = 1 and channels = 1. You can use load_psf() to load a PSF from a file such that it is in the correct format.
dtype (float32 or float64) – Data type to use for optimization. Default is float32.
proj (function) – Projection function to apply at each iteration. Default is non-negative.
p (float) – Momentum parameter that keeps track of changes. By default, this is initialized to 0.
mu (float) – Momentum parameter. Default is 0.9.
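The roles of p and mu can be illustrated on a toy problem. The sketch below uses one common formulation of Nesterov momentum, in which p accumulates past gradient steps and mu controls how much of it is kept; whether the class uses exactly this form is an assumption. The function nesterov_pgd and the quadratic objective are hypothetical.

```python
import numpy as np


def non_neg(x):
    """Projection onto the non-negative orthant."""
    return np.maximum(x, 0)


def nesterov_pgd(grad, x0, step, mu=0.9, p=0.0, n_iter=200, proj=non_neg):
    """Projected gradient descent with a Nesterov-style momentum term."""
    x = np.array(x0, dtype=float)
    p = np.full_like(x, p)  # running momentum, initialized to the scalar p
    for _ in range(n_iter):
        p_prev = p
        p = mu * p - step * grad(x)
        # Look-ahead correction: combine the new and previous momentum.
        x = proj(x + (1 + mu) * p - mu * p_prev)
    return x


# Toy objective 0.5 * ||x - target||^2, whose gradient is x - target.
target = np.array([2.0, -1.0, 0.5])
estimate = nesterov_pgd(lambda x: x - target, x0=np.zeros(3), step=0.1)
```

On this problem the estimate converges to the target clipped at zero, since the negative entry is removed by the projection.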
- class lensless.FISTA(psf, dtype=None, proj=<function non_neg>, tk=1.0, **kwargs)
Object for applying projected gradient descent with FISTA (Fast Iterative Shrinkage-Thresholding Algorithm) for acceleration.
Paper: https://www.ceremade.dauphine.fr/~carlier/FISTA
- __init__(psf, dtype=None, proj=<function non_neg>, tk=1.0, **kwargs)
- Parameters
psf (ndarray or Tensor) – Point spread function (PSF) that models forward propagation. Must be of shape (depth, height, width, channels) even if depth = 1 and channels = 1. You can use load_psf() to load a PSF from a file such that it is in the correct format.
dtype (float32 or float64) – Data type to use for optimization. Default is float32.
proj (function) – Projection function to apply at each iteration. Default is non-negative.
tk (float) – Initial step size parameter for FISTA. It is updated at each iteration according to Eq. 4.2 of paper. By default, initialized to 1.0.
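The tk recursion is short enough to write out. The sketch below assumes the standard FISTA rule t_{k+1} = (1 + sqrt(1 + 4 t_k^2)) / 2 and the usual extrapolation weight (t_k - 1) / t_{k+1}; the helper next_tk is hypothetical and only tabulates these values, it does not run a reconstruction.

```python
import math


def next_tk(tk):
    """FISTA update t_{k+1} = (1 + sqrt(1 + 4 t_k^2)) / 2."""
    return (1.0 + math.sqrt(1.0 + 4.0 * tk**2)) / 2.0


tk = 1.0  # initial value, matching the documented default
momentum_weights = []
for _ in range(5):
    tk_next = next_tk(tk)
    # Weight applied to (x_k - x_{k-1}) when forming the extrapolated point.
    momentum_weights.append((tk - 1.0) / tk_next)
    tk = tk_next
```

Starting from tk = 1 the first weight is zero, so the first iteration is a plain projected gradient step; the weights then grow towards 1 as iterations proceed.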
ADMM
- class lensless.ADMM(psf, dtype=None, mu1=1e-06, mu2=1e-05, mu3=4e-05, tau=0.0001, psi=None, psi_adj=None, psi_gram=None, pad=False, norm='backward', denoiser=None, **kwargs)
Object for applying ADMM (Alternating Direction Method of Multipliers) with a non-negativity constraint and a total variation (TV) prior.
Paper about ADMM: https://web.stanford.edu/~boyd/papers/pdf/admm_distr_stats.pdf
Slides about ADMM: https://web.stanford.edu/class/ee364b/lectures/admm_slides.pdf
- __init__(psf, dtype=None, mu1=1e-06, mu2=1e-05, mu3=4e-05, tau=0.0001, psi=None, psi_adj=None, psi_gram=None, pad=False, norm='backward', denoiser=None, **kwargs)
- Parameters
psf (ndarray or Tensor) – Point spread function (PSF) that models forward propagation. Must be of shape (depth, height, width, channels) even if depth = 1 and channels = 1. You can use load_psf() to load a PSF from a file such that it is in the correct format.
dtype (float32 or float64) – Data type to use for optimization. Default is float32.
mu1 (float) – Step size for updating primal/dual variables.
mu2 (float) – Step size for updating primal/dual variables.
mu3 (float) – Step size for updating primal/dual variables.
tau (float) – Weight for L1 norm of psi applied to the image estimate.
psi (function, optional) – Operator to map image to a space that the image is assumed to be sparse in (hence L1 norm). Default is to use total variation (TV) operator.
psi_adj (function) – Adjoint of psi.
psi_gram (function) – Function to compute gram of psi.
pad (bool) – Whether to pad the image with zeros before applying the PSF. Default is False, as optimized data is already padded.
norm (str) – Normalization to use for the convolution. Options are “forward”, “backward”, and “ortho”. Default is “backward”.
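When supplying a custom psi, the three functions have to be mutually consistent. The sketch below builds a TV-style operator from circular finite differences on a 2D array, with its adjoint and the frequency response of its Gram operator, and checks both relations numerically. The signatures and array layout here are assumptions for illustration and may not match what ADMM expects internally.

```python
import numpy as np


def psi(x):
    """Circular finite differences along rows and columns, stacked on a new last axis."""
    return np.stack((np.roll(x, 1, axis=0) - x, np.roll(x, 1, axis=1) - x), axis=-1)


def psi_adj(u):
    """Adjoint of psi: satisfies <psi(x), u> == <x, psi_adj(u)>."""
    return (np.roll(u[..., 0], -1, axis=0) - u[..., 0]) + (
        np.roll(u[..., 1], -1, axis=1) - u[..., 1]
    )


def psi_gram(shape):
    """Frequency response of psi_adj(psi(.)) for a 2D image of the given shape."""
    kernel = np.zeros(shape)
    kernel[0, 0] = 4
    kernel[0, 1] = kernel[1, 0] = kernel[0, -1] = kernel[-1, 0] = -1
    return np.fft.fft2(kernel)


rng = np.random.default_rng(0)
x = rng.standard_normal((16, 20))
u = rng.standard_normal((16, 20, 2))

# Adjoint test: both inner products should agree.
lhs = np.sum(psi(x) * u)
rhs = np.sum(x * psi_adj(u))

# Gram test: filtering in the Fourier domain should match psi_adj(psi(x)).
via_fft = np.real(np.fft.ifft2(psi_gram(x.shape) * np.fft.fft2(x)))
```

Running both checks on random inputs is a cheap way to catch a sign or shift error before plugging a new operator into an iterative solver.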
Tikhonov (Ridge Regression)
- class lensless.CodedApertureReconstruction(mask, image_shape, P=None, Q=None, lmbd=0.0003)
Reconstruction method for the (non-iterative) Tikhonov algorithm presented in the FlatCam paper.
TODO: operations in float32
- __init__(mask, image_shape, P=None, Q=None, lmbd=0.0003)
- Parameters
mask (lensless.hardware.mask.CodedAperture) – Coded aperture mask object.
image_shape (array-like or tuple) – The shape of the image to reconstruct.
P (ndarray, optional) – Left convolution matrix in measurement operator. Must be of shape (measurement_resolution[0], image_shape[0]). By default, it is generated from the mask. In practice, it may be useful to measure it, as in the FlatCam paper.
Q (ndarray, optional) – Right convolution matrix in measurement operator. Must be of shape (measurement_resolution[1], image_shape[1]). By default, it is generated from the mask. In practice, it may be useful to measure it, as in the FlatCam paper.
lmbd (float) – Regularization parameter. Default value is 3e-4 as in the FlatCam paper code.
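To show why this reconstruction is non-iterative, the sketch below solves a separable ridge problem in closed form. It assumes the measurement model Y = P X Q^T for a single channel, which is consistent with the documented shapes of P and Q but is not taken from the library's code; the function separable_tikhonov and the random matrices are hypothetical.

```python
import numpy as np


def separable_tikhonov(Y, P, Q, lmbd=3e-4):
    """Solve min_X ||P X Q^T - Y||_F^2 + lmbd ||X||_F^2 in closed form via SVDs."""
    UL, sL, VLt = np.linalg.svd(P, full_matrices=False)
    UR, sR, VRt = np.linalg.svd(Q, full_matrices=False)
    # Rotate the measurement into the singular bases of P and Q.
    Y_rot = UL.T @ Y @ UR
    # Per-coefficient ridge solution: s_i t_j y_ij / (s_i^2 t_j^2 + lmbd).
    gains = np.outer(sL, sR)
    Z = gains * Y_rot / (gains**2 + lmbd)
    return VLt.T @ Z @ VRt


rng = np.random.default_rng(0)
X_true = rng.random((12, 10))
P = rng.standard_normal((20, 12))  # (measurement rows, image rows)
Q = rng.standard_normal((22, 10))  # (measurement cols, image cols)
Y = P @ X_true @ Q.T

X_hat = separable_tikhonov(Y, P, Q, lmbd=1e-12)
```

With a nearly zero regularization weight and noise-free data the image is recovered almost exactly; larger values of lmbd shrink the coefficients associated with small singular values, which is what stabilizes the inversion on noisy measurements.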
Accelerated Proximal Gradient Descent (APGD)
- class lensless.recon.apgd.APGD(psf, max_iter=500, dtype='float32', diff_penalty=None, prox_penalty='nonneg', acceleration=True, diff_lambda=0.001, prox_lambda=0.001, disp=100, rel_error=None, lipschitz_tight=True, lipschitz_tol=1.0, img_shape=None, **kwargs)
- __init__(psf, max_iter=500, dtype='float32', diff_penalty=None, prox_penalty='nonneg', acceleration=True, diff_lambda=0.001, prox_lambda=0.001, disp=100, rel_error=None, lipschitz_tight=True, lipschitz_tol=1.0, img_shape=None, **kwargs)
Wrapper for Pycsou’s PGD (accelerated proximal gradient descent) applied to lensless imaging.
- Parameters
psf (ndarray) – Point spread function (PSF) that models forward propagation. Must be of shape (depth, height, width, channels) even if depth = 1 and channels = 1. You can use load_psf() to load a PSF from a file such that it is in the correct format.
max_iter (int, optional) – Maximal number of iterations.
dtype (float32 or float64) – Data type to use for optimization.
diff_penalty (None or str or DiffFunc) –
Differentiable functional to serve as prior / regularization term. Default is None. See DiffFunc.
prox_penalty (None or str or ProxFunc) –
Proximal functional to serve as prior / regularization term. Default is non-negative prior. See ProxFunc.
acceleration (bool, optional) – Whether to use acceleration or not. Default is True.
diff_lambda (float) – Weight of differentiable penalty.
prox_lambda (float) – Weight of proximal penalty.
disp (int, optional) – Display frequency. Default is 100.
rel_error (float, optional) – Relative error to stop optimization. Default is 1e-6.
lipschitz_tight (bool, optional) – Whether to use tight Lipschitz constant or not. Default is True.
lipschitz_tol (float, optional) – Tolerance to compute Lipschitz constant. Default is 1.
img_shape (tuple, optional) – Shape of measurement (H, W, C). If None, assume shape of PSF.
Trainable reconstruction API
Abstract Class (Trainable)
- class lensless.TrainableReconstructionAlgorithm(psf, dtype=None, n_iter=1, pre_process=None, post_process=None, skip_unrolled=False, skip_pre=False, skip_post=False, return_intermediate=False, legacy_denoiser=False, compensation=None, compensation_residual=True, psf_network=None, psf_residual=True, direct_background_subtraction=False, background_network=None, integrated_background_subtraction=False, **kwargs)
Bases: ReconstructionAlgorithm, Module
Abstract class for defining lensless imaging reconstruction algorithms with trainable parameters.
The following abstract methods need to be defined:
_update: updating state variables at each iteration.
reset: reset state variables.
_form_image: any pre-processing that needs to be done in order to view the image estimate, e.g. reshaping or clipping.
One advantage of deriving from this abstract class is that functionality for iterating, saving, and visualization is already implemented, namely in the apply method.
Consequently, using a reconstruction algorithm that derives from it boils down to four steps:
Creating an instance of the reconstruction algorithm.
Training the algorithm.
Setting the data.
Applying the algorithm.
- __init__(psf, dtype=None, n_iter=1, pre_process=None, post_process=None, skip_unrolled=False, skip_pre=False, skip_post=False, return_intermediate=False, legacy_denoiser=False, compensation=None, compensation_residual=True, psf_network=None, psf_residual=True, direct_background_subtraction=False, background_network=None, integrated_background_subtraction=False, **kwargs)
Base constructor. Derived constructor may define new state variables here and also reset them in reset.
- Parameters
psf (Tensor) – Point spread function (PSF) that models forward propagation. Must be of shape (depth, height, width, channels) even if depth = 1 and channels = 1. You can use load_psf() to load a PSF from a file such that it is in the correct format.
dtype (float32 or float64) – Data type to use for optimization.
n_iter (int) – Number of iterations for unrolled algorithm.
pre_process (function or Module, optional) – If function: Function to apply to the image estimate before the algorithm. Its input must be (image to process, noise_level), where noise_level is a learnable parameter. If it includes additional learnable parameters, they will not be added to the parameter list of the algorithm. To allow for training, the function must be autograd compatible. If Module: A DruNet compatible network to apply to the image estimate before the algorithm. See utils.image.apply_denoiser for more details. The network will be included as a submodule of the algorithm and its parameters will be added to the parameter list of the algorithm. If this isn’t intended behavior, set requires_grad=False.
post_process (function or Module, optional) – If function: Function to apply to the image estimate after the whole algorithm. Its input must be (image to process, noise_level), where noise_level is a learnable parameter. If it includes additional learnable parameters, they will not be added to the parameter list of the algorithm. To allow for training, the function must be autograd compatible. If Module: A DruNet compatible network to apply to the image estimate after the whole algorithm. See utils.image.apply_denoiser for more details. The network will be included as a submodule of the algorithm and its parameters will be added to the parameter list of the algorithm. If this isn’t intended behavior, set requires_grad=False.
skip_unrolled (bool, optional) – Whether to skip the unrolled algorithm and only apply the pre- or post-processor block (e.g. to just use a U-Net for reconstruction).
return_unrolled_output (bool, optional) – Whether to return the output of the unrolled algorithm if also using a post-processor block.
compensation (list, optional) – Number of channels for each intermediate output in compensation layer, as in “Robust Reconstruction With Deep Learning to Handle Model Mismatch in Lensless Imaging” (2021). Post-processor must be defined if compensation provided.
compensation_residual (bool, optional) – Whether to use residual connection in compensation layer.
psf_network (function or Module, optional) – Function or model to apply to PSF prior to camera inversion.
psf_residual (bool, optional) – Whether to use residual connection in PSF network.
- apply(disp_iter=10, plot_pause=0.2, plot=True, save=False, gamma=None, ax=None, reset=True, output_intermediate=False, background=None)
Method for performing iterative reconstruction. Contrary to non-trainable reconstruction algorithms, the number of iterations isn’t required. Note that set_data must be called beforehand.
- Parameters
disp_iter (int) – How often to display and/or save intermediate reconstructions (in number of iterations). If None, or if plot and save are both False, no intermediate reconstruction will be plotted/saved.
plot_pause (float) – Number of seconds to pause after displaying reconstruction.
plot (bool) – Whether to plot final result, and intermediate results if disp_iter is not None.
save (bool) – Whether to save final result (as PNG), and intermediate results if disp_iter is not None.
gamma (float, optional) – Gamma correction factor to apply for plots. Default is None.
ax (Axes, optional) – Axes object to fill for plotting/saving, default is to create one.
output_intermediate (bool, optional) – Whether to output intermediate reconstructions after preprocessing and before postprocessing.
- Returns
final_im (Tensor) – Final reconstruction.
ax (Axes) – Axes object on which final reconstruction is displayed. Only returned if plot or save is True.
- forward(batch, psfs=None, background=None)
Method for performing iterative reconstruction on a batch of images. This implementation is a properly vectorized implementation of FISTA.
- Parameters
- Returns
The reconstructed images.
- Return type
Tensor of shape (batch, depth, channels, height, width)
- abstract reset()
Reset state variables.
Unrolled FISTA
- class lensless.UnrolledFISTA(psf, n_iter=5, dtype=None, proj=<function non_neg>, learn_tk=True, tk=1, **kwargs)
Bases: TrainableReconstructionAlgorithm
Object for applying unrolled projected gradient descent with FISTA (Fast Iterative Shrinkage-Thresholding Algorithm) for acceleration.
FISTA Paper: https://www.ceremade.dauphine.fr/~carlier/FISTA
- __init__(psf, n_iter=5, dtype=None, proj=<function non_neg>, learn_tk=True, tk=1, **kwargs)
Constructor for unrolled FISTA algorithm.
- Parameters
psf (Tensor) – Point spread function (PSF) that models forward propagation. Must be of shape (depth, height, width, channels) even if depth = 1 and channels = 1. You can use load_psf() to load a PSF from a file such that it is in the correct format.
n_iter (int, optional) – Number of iterations to unroll, by default 5.
dtype (float32 or float64) – Data type to use for optimization.
proj (function, optional) – Projection function to apply at each iteration, by default non_neg.
learn_tk (bool, optional) – Whether the tk parameters of FISTA should be learnt, by default True.
tk (int, optional) – Initial value of tk, by default 1.
Unrolled ADMM
- class lensless.UnrolledADMM(psf, dtype=None, n_iter=5, mu1=1e-06, mu2=1e-05, mu3=4e-05, tau=0.0001, psi=None, psi_adj=None, psi_gram=None, pad=False, norm='backward', **kwargs)
Bases: TrainableReconstructionAlgorithm
Object for applying unrolled version of ADMM.
- __init__(psf, dtype=None, n_iter=5, mu1=1e-06, mu2=1e-05, mu3=4e-05, tau=0.0001, psi=None, psi_adj=None, psi_gram=None, pad=False, norm='backward', **kwargs)
- Parameters
psf (Tensor) – Point spread function (PSF) that models forward propagation. Must be of shape (depth, height, width, channels) even if depth = 1 and channels = 1. You can use load_psf() to load a PSF from a file such that it is in the correct format.
dtype (float32 or float64) – Data type to use for optimization. Default is float32.
n_iter (int, optional) – Number of iterations to unroll, by default 5.
mu1 (float) – Initial step size for updating primal/dual variables.
mu2 (float) – Initial step size for updating primal/dual variables.
mu3 (float) – Initial step size for updating primal/dual variables.
tau (float) – Initial weight for L1 norm of psi applied to the image estimate.
psi (function, optional) – Operator to map image to a space that the image is assumed to be sparse in (hence L1 norm). Default is to use total variation (TV) operator.
psi_adj (function) – Adjoint of psi.
psi_gram (function) – Function to compute gram of psi.
pad (bool, optional) – Whether to pad the image with zeros before applying the PSF.
norm (str) – Normalization to use for the convolution. Options are “forward”, “backward”, and “ortho”. Default is “backward”.
Trainable Inversion
- class lensless.TrainableInversion(psf, dtype=None, K=0.0001, **kwargs)
Bases: TrainableReconstructionAlgorithm
- __init__(psf, dtype=None, K=0.0001, **kwargs)
Constructor for trainable inversion component as proposed in the FlatNet work: https://siddiquesalman.github.io/flatnet/
Their implementation: https://github.com/siddiquesalman/flatnet/blob/d1dc179666a08df58c4bdf83c4274310ba3cd186/models/fftlayer.py#L70
- Parameters
psf (Tensor) – Point spread function (PSF) that models forward propagation. Must be of shape (depth, height, width, channels) even if depth = 1 and channels = 1. You can use load_psf() to load a PSF from a file such that it is in the correct format.
dtype (float32 or float64) – Data type to use for optimization.
K (float) – Regularization parameter.
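The role of a regularization parameter like K can be illustrated with a Wiener-style inverse filter in the Fourier domain. This sketch assumes the form X = conj(H) Y / (|H|^2 + K) on a single-channel image with circular convolution and a fixed K; it does not attempt to reproduce the FlatNet layer or the class's training behavior, and regularized_inverse is a hypothetical name.

```python
import numpy as np


def regularized_inverse(measurement, psf, K=1e-4):
    """Fourier-domain inversion: X = conj(H) * Y / (|H|^2 + K)."""
    H = np.fft.fft2(psf)
    Y = np.fft.fft2(measurement)
    X = np.conj(H) * Y / (np.abs(H) ** 2 + K)
    return np.real(np.fft.ifft2(X))


rng = np.random.default_rng(0)
scene = rng.random((32, 32))

# Well-conditioned toy PSF: a delta plus two small taps, so |H| >= 0.6 everywhere.
psf = np.zeros((32, 32))
psf[0, 0] = 1.0
psf[1, 0] = psf[0, 1] = 0.2

# Simulated noise-free measurement via circular convolution.
measurement = np.real(np.fft.ifft2(np.fft.fft2(psf) * np.fft.fft2(scene)))

estimate = regularized_inverse(measurement, psf, K=1e-8)
```

K keeps the denominator away from zero at frequencies where the PSF carries little energy; a larger K trades sharpness for robustness to noise.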
Multi-Wiener Deconvolution Network
- class lensless.MultiWiener(in_channels, out_channels, psf, psf_channels=1, nc=None, pre_process=None, skip_pre=False)
Bases: Module
- __init__(in_channels, out_channels, psf, psf_channels=1, nc=None, pre_process=None, skip_pre=False)
Constructor for Multi-Wiener Deconvolution Network (MWDN) as proposed in: https://opg.optica.org/oe/fulltext.cfm?uri=oe-31-23-39088&id=541387
- Parameters
in_channels (int) – Number of input channels. RGB or grayscale, i.e. 3 and 1 respectively.
out_channels (int) – Number of output channels. RGB or grayscale, i.e. 3 and 1 respectively.
psf (Tensor) – Point spread function (PSF) that models forward propagation.
psf_channels (int) – Number of channels in the PSF. Default is 1.
nc (list) – Number of channels in the network. Default is [64, 128, 256, 512, 512].
pre_process (function or Module, optional) – Pre-processor applied before MWDN. Default is None.
skip_pre (bool) – Skip pre-processing. Default is False.
- forward(batch, psfs=None, **kwargs)
Define the computation performed at every call.
Should be overridden by all subclasses.
Note
Although the recipe for forward pass needs to be defined within this function, one should call the Module instance afterwards instead of this since the former takes care of running the registered hooks while the latter silently ignores them.
Reconstruction Utilities
- class lensless.recon.utils.Trainer(recon, train_dataset, test_dataset, test_size=0.15, mask=None, batch_size=4, eval_batch_size=10, loss='l2', lpips=None, l1_mask=None, optimizer=None, skip_NAN=False, algorithm_name='Unknown', metric_for_best_model=None, save_every=None, gamma=None, logger=None, crop=None, clip_grad=1.0, unrolled_output_factor=False, random_rotate=False, random_shift=False, pre_proc_aux=False, extra_eval_sets=None, use_wandb=False, pre_process=None, pre_process_delay=None, pre_process_freeze=None, pre_process_unfreeze=None, post_process=None, post_process_delay=None, post_process_freeze=None, post_process_unfreeze=None, n_epoch=None)
- __init__(recon, train_dataset, test_dataset, test_size=0.15, mask=None, batch_size=4, eval_batch_size=10, loss='l2', lpips=None, l1_mask=None, optimizer=None, skip_NAN=False, algorithm_name='Unknown', metric_for_best_model=None, save_every=None, gamma=None, logger=None, crop=None, clip_grad=1.0, unrolled_output_factor=False, random_rotate=False, random_shift=False, pre_proc_aux=False, extra_eval_sets=None, use_wandb=False, pre_process=None, pre_process_delay=None, pre_process_freeze=None, pre_process_unfreeze=None, post_process=None, post_process_delay=None, post_process_freeze=None, post_process_unfreeze=None, n_epoch=None)
Class to train a reconstruction algorithm. Inspired by Trainer from HuggingFace.
The train and test metrics at the end of each epoch can be found in self.metrics, with “LOSS” being the train loss. The test loss can be found in “MSE” (if loss is “l2”) or “MAE” (if loss is “l1”). If lpips is not None, the LPIPS loss is also added to the train loss, such that the test loss can be computed as “MSE” + lpips * “LPIPS_Vgg” (or “MAE” + lpips * “LPIPS_Vgg”).
- Parameters
recon (lensless.TrainableReconstructionAlgorithm) – Reconstruction algorithm to train.
train_dataset (torch.utils.data.Dataset) – Dataset to use for training.
test_dataset (torch.utils.data.Dataset) – Dataset to use for testing.
test_size (float, optional) – If test_dataset is None, fraction of the train dataset to use for testing, by default 0.15.
mask (TrainableMask, optional) – Trainable mask to use for training. If None, training is done with a fixed PSF. By default None.
batch_size (int, optional) – Batch size to use for training, by default 4.
eval_batch_size (int, optional) – Batch size to use for evaluation, by default 10.
loss (str, optional) – Loss function to use for training, “l1” or “l2”, by default “l2”.
lpips (float, optional) – The weight of the LPIPS (VGG) term in the total loss. If None, it is ignored. By default None.
l1_mask (float, optional) – The weight of the L1 norm of the mask in the total loss. If None, it is ignored. By default None.
optimizer (dict) – Optimizer configuration.
skip_NAN (bool, optional) – Whether to skip the update if any gradients are NaN (True) or to throw an error (False), by default False.
algorithm_name (str, optional) – Algorithm name for logging, by default “Unknown”.
metric_for_best_model (str, optional) – Metric to use for saving the best model. If None, will default to evaluation loss. Default is None.
save_every (int, optional) – Save model every save_every epochs. If None, just save best model.
gamma (float, optional) – Gamma correction to apply to PSFs when plotting. If None, no gamma correction is applied. Default is None.
logger (logging.Logger, optional) – Logger to use for logging. If None, just print to terminal. Default is None.
crop (dict, optional) – Crop to apply to images before computing loss (by applying a mask). If None, no crop is applied. Default is None.
unrolled_output_factor (float, optional) – How much of the unrolled loss to add to the total loss. If False, no unrolled loss is added. Default is False. Only applicable if a post-processor is used.
pre_process (torch.nn.Module, optional) – Pre process component to add during training. Default is None.
pre_process_delay (int, optional) – Epoch at which to add pre process component. Default is None.
pre_process_freeze (int, optional) – Epoch at which to freeze pre process component. Default is None.
pre_process_unfreeze (int, optional) – Epoch at which to unfreeze pre process component. Default is None.
post_process (torch.nn.Module, optional) – Post process component to add during training. Default is None.
post_process_delay (int, optional) – Epoch at which to add post process component. Default is None.
post_process_freeze (int, optional) – Epoch at which to freeze post process component. Default is None.
post_process_unfreeze (int, optional) – Epoch at which to unfreeze post process component. Default is None.
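The relation between the logged metrics and the test loss described for Trainer can be written as a small helper. The function combined_test_loss and the metric values are hypothetical, and the sketch assumes one scalar per metric for a single epoch; how self.metrics is actually organized across epochs is not specified here.

```python
def combined_test_loss(metrics, loss="l2", lpips=None):
    """Combine test metrics as described: MSE or MAE, plus weighted LPIPS if used."""
    base_key = {"l2": "MSE", "l1": "MAE"}[loss]
    total = metrics[base_key]
    if lpips is not None:
        total = total + lpips * metrics["LPIPS_Vgg"]
    return total


# Hypothetical metric values for a single epoch.
epoch_metrics = {"MSE": 0.02, "MAE": 0.10, "LPIPS_Vgg": 0.5}
combined = combined_test_loss(epoch_metrics, loss="l2", lpips=0.6)
```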
- evaluate(mean_loss, epoch, disp=None)
Evaluate the reconstruction algorithm on the test dataset.
- Parameters
mean_loss (float) – Mean loss of the last epoch.
disp (list of int, optional) – Test set examples to visualize at the end of each epoch, by default None.
- train_epoch(data_loader)
Train for one epoch.
- Parameters
data_loader (torch.utils.data.DataLoader) – Data loader to use for training.
- Returns
Mean loss of the epoch.
- Return type
- lensless.recon.utils.load_drunet(model_path=None, n_channels=3, requires_grad=False)
Load a pre-trained Drunet model.
- Parameters
- Returns
model – Loaded model.
- Return type
- lensless.recon.utils.apply_denoiser(model, image, noise_level=10, mode='inference', compensation_output=None, background=None)
Apply a pre-trained denoising model with input in the format Channel, Height, Width. An additional channel is added for the noise level as done in Drunet.
- Parameters
model (torch.nn.Module) – Drunet compatible model. Its input must consist of 4 channels (RGB + noise level) and output an RGB image both in CHW format.
image (torch.Tensor) – Input image.
noise_level (float or torch.Tensor) – Noise level in the image within [0, 255].
background (torch.Tensor, optional) – If provided, use background as noise channel instead of noise level.
device (str) – Device to use for computation. Can be “cpu” or “cuda”.
mode (str) – Mode to use for model. Can be “inference” or “train”.
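The extra noise-level channel can be pictured as a constant map appended to the RGB image. The sketch below uses NumPy rather than torch, and the scaling by 255 is an assumption based on the documented [0, 255] range; add_noise_level_channel is a hypothetical helper, not the function used inside apply_denoiser.

```python
import numpy as np


def add_noise_level_channel(image_chw, noise_level=10):
    """Append a constant map holding noise_level / 255 as an extra channel."""
    _, height, width = image_chw.shape
    noise_map = np.full((1, height, width), noise_level / 255.0, dtype=image_chw.dtype)
    return np.concatenate((image_chw, noise_map), axis=0)


rgb = np.zeros((3, 8, 8), dtype=np.float32)
network_input = add_noise_level_channel(rgb, noise_level=10)
```

The result has 4 channels (RGB + noise level) in CHW format, which is the input layout described for the model parameter.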
- Returns
image – Reconstructed image.
- Return type
- lensless.recon.utils.get_drunet_function(model, mode='inference')
Return a processing function that applies the DruNet model to an image. Legacy function to work with pre-trained models, use get_drunet_function_v2 instead.
- Parameters
model (torch.nn.Module) – DruNet-like denoiser model.
device (str) – Device to use for computation. Can be “cpu” or “cuda”.
mode (str) – Mode to use for model. Can be “inference” or “train”.
- lensless.recon.utils.measure_gradient(model)
Helper function to measure L2 norm of the gradient of a model.
- Parameters
model (torch.nn.Module) – Model to measure gradient of.
- Returns
L2 norm of the gradient of the model.
- Return type
Float
- lensless.recon.utils.create_process_network(network, device='cpu', device_ids=None, concatenate_compensation=False, background_subtraction=False, input_background=False, depth=4, nc=None, restormer_params=None)
Helper function to create a process network.
- Parameters
- Returns
New process network. Already trained for Drunet.
- Return type