# Source code for lensless.utils.simulation

# #############################################################################
# simulation.py
# =================
# Authors :
# Yohann PERRON [yohann.perron@gmail.com]
# Eric BEZZAM [ebezzam@gmail.com]
# #############################################################################

from waveprop.simulation import FarFieldSimulator as FarFieldSimulator_wp
import torch


class FarFieldSimulator(FarFieldSimulator_wp):
    """
    LenslessPiCam-compatible wrapper for :py:class:`~waveprop.simulation.FarFieldSimulator`
    (source code on `GitHub <https://github.com/ebezzam/waveprop/blob/82dfb08b4db11c0c07ef00bdb59b5a769a49f0b3/waveprop/simulation.py#L11C11-L11C11>`__).

    This wrapper exchanges PSFs with its callers as 4D arrays of shape
    (depth, height, width, channels). Before handing a PSF to the parent class
    it drops the depth dimension and, for PyTorch tensors, moves the channel
    dimension first. Images passed to :py:meth:`propagate_image` are expected
    with channels last.
    """

    @staticmethod
    def _to_waveprop_psf(psf):
        """
        Validate a 4D PSF and convert it to the layout handed to the parent class.

        Parameters
        ----------
        psf : np.ndarray or torch.Tensor
            Point spread function of shape (depth, height, width, channels).

        Returns
        -------
        np.ndarray or torch.Tensor
            First depth slice of ``psf``: (channels, height, width) for a
            PyTorch tensor, (height, width, channels) otherwise.

        Raises
        ------
        AssertionError
            If ``psf`` is not 4D or does not have 1 or 3 channels.
        """
        # Kept as `assert` (rather than raising ValueError) so that callers
        # observe the same exception type as before.
        assert len(psf.shape) == 4, "PSF must be of shape (depth, height, width, channels)"
        if torch.is_tensor(psf):
            # drop depth dimension, and convert HWC to CHW
            psf = psf[0].movedim(-1, 0)
            channel_axis = 0
        else:
            # drop depth dimension, channels stay in the last dimension
            psf = psf[0]
            channel_axis = -1
        assert psf.shape[channel_axis] in (1, 3), "PSF must have 1 or 3 channels"
        return psf

    def __init__(
        self,
        object_height,
        scene2mask,
        mask2sensor,
        sensor,
        psf=None,
        output_dim=None,
        snr_db=None,
        max_val=255,
        device_conv="cpu",
        random_shift=False,
        is_torch=False,
        quantize=True,
        **kwargs
    ):
        """
        Parameters
        ----------
        object_height : float or (float, float)
            Height of object in meters. Or range of values to randomly sample from.
        scene2mask : float
            Distance from scene to mask in meters.
        mask2sensor : float
            Distance from mask to sensor in meters.
        sensor : str
            Sensor name.
        psf : np.ndarray or torch.Tensor, optional.
            Point spread function of shape (depth, height, width, channels).
            If not provided, return image at object plane.
        output_dim : optional
            Forwarded unchanged to the parent class, by default None.
        snr_db : float, optional
            Signal-to-noise ratio in dB, by default None.
        max_val : int, optional
            Maximum value of image, by default 255.
        device_conv : str, optional
            Device to use for convolution (when using pytorch), by default "cpu".
        random_shift : bool, optional
            Whether to randomly shift the image, by default False.
        is_torch : bool, optional
            Whether to use pytorch, by default False.
        quantize : bool, optional
            Whether to quantize image, by default True.
        **kwargs
            Extra keyword arguments, forwarded to the parent class and recorded
            in ``self.params``.

        Raises
        ------
        AssertionError
            If ``psf`` is provided but is not 4D or does not have 1 or 3 channels.
        """
        if psf is not None:
            psf = self._to_waveprop_psf(psf)

        super().__init__(
            object_height,
            scene2mask,
            mask2sensor,
            sensor,
            psf,
            output_dim,
            snr_db,
            max_val,
            device_conv,
            random_shift,
            is_torch,
            quantize,
            **kwargs
        )

        if psf is not None:
            # Re-check the channel count on the PSF as stored by the parent.
            # NOTE(review): presumably the parent may convert the PSF (and set
            # `self.is_torch`) during initialization -- confirm against waveprop.
            channel_axis = 0 if self.is_torch else -1
            assert self.psf.shape[channel_axis] in (1, 3), "PSF must have 1 or 3 channels"

        # save all the parameters in a dict (the PSF itself is not stored here)
        self.params = {
            "object_height": object_height,
            "scene2mask": scene2mask,
            "mask2sensor": mask2sensor,
            "sensor": sensor,
            "output_dim": output_dim,
            "snr_db": snr_db,
            "max_val": max_val,
            "device_conv": device_conv,
            "random_shift": random_shift,
            "is_torch": is_torch,
            "quantize": quantize,
        }
        self.params.update(kwargs)

    def get_psf(self):
        """
        Return the current PSF with a leading singleton depth dimension.

        Returns
        -------
        np.ndarray or torch.Tensor or None
            ``self.psf`` with a dimension of size 1 prepended; for PyTorch the
            first dimension (channels) is moved last beforehand. ``None`` if no
            PSF is set.
        """
        # `getattr` because it is not visible here whether the parent class
        # always defines `psf` when none was provided.
        psf = getattr(self, "psf", None)
        if psf is None:
            return None
        if self.is_torch:
            # convert CHW to HWC
            return psf.movedim(0, -1).unsqueeze(0)
        return psf[None, ...]

    # needs different name from parent class
    def set_point_spread_function(self, psf):
        """
        Set point spread function.

        Parameters
        ----------
        psf : np.ndarray or torch.Tensor
            Point spread function of shape (depth, height, width, channels).

        Returns
        -------
        Whatever the parent's ``set_psf`` returns.

        Raises
        ------
        AssertionError
            If ``psf`` is not 4D or does not have 1 or 3 channels.
        """
        return super().set_psf(self._to_waveprop_psf(psf))

    def propagate_image(self, obj, return_object_plane=False):
        """
        Propagate a single image with the parent's ``propagate``.

        Parameters
        ----------
        obj : np.ndarray or torch.Tensor
            Single image to propagate of format HWC.
        return_object_plane : bool, optional
            Whether to return object plane, by default False.

        Returns
        -------
        Output of the parent's ``propagate``. For PyTorch, the third-to-last
        dimension of each returned tensor is moved last (channels last again);
        if the parent returns a tuple, this is applied to its first two elements.

        Raises
        ------
        AssertionError
            If the last dimension of ``obj`` is not 1 or 3.
        """
        assert obj.shape[-1] == 1 or obj.shape[-1] == 3, "Image must have 1 or 3 channels"

        if not self.is_torch:
            # TODO: not tested, but normally don't need to move dimensions for numpy
            return super().propagate(obj, return_object_plane)

        # channel in first dimension as expected by waveprop for pytorch
        obj = obj.moveaxis(-1, 0)
        res = super().propagate(obj, return_object_plane)
        if isinstance(res, tuple):
            return res[0].moveaxis(-3, -1), res[1].moveaxis(-3, -1)
        return res.moveaxis(-3, -1)