Source code for lensless.hardware.trainable_mask

# #############################################################################
# trainable_mask.py
# ==================
# Authors :
# Yohann PERRON [yohann.perron@gmail.com]
# Eric BEZZAM [ebezzam@gmail.com]
# #############################################################################

import abc
import omegaconf
import numpy as np
import torch
from lensless.utils.image import is_grayscale, rgb2gray
from lensless.hardware.slm import full2subpattern, get_programmable_mask, get_intensity_psf
from lensless.hardware.sensor import VirtualSensor
from waveprop.devices import slm_dict
from lensless.hardware.mask import CodedAperture


[docs]class TrainableMask(torch.nn.Module, metaclass=abc.ABCMeta): """ Abstract class for defining trainable masks. The following abstract methods need to be defined: - :py:class:`~lensless.hardware.trainable_mask.TrainableMask.get_psf`: returning the PSF of the mask. - :py:class:`~lensless.hardware.trainable_mask.TrainableMask.project`: projecting the mask parameters to a valid space (should be a subspace of [0,1]). """
[docs] def __init__(self, optimizer="Adam", lr=1e-3, **kwargs): """ Base constructor. Derived constructor may define new state variables Parameters ---------- optimizer : str, optional Optimizer to use for updating the mask parameters, by default "Adam" lr : float, optional Learning rate for the mask parameters, by default 1e-3 """ super().__init__() self._optimizer = optimizer self._lr = lr self._counter = 0
def _set_optimizer(self, param): """Set the optimizer for the mask parameters.""" self._optimizer = getattr(torch.optim, self._optimizer)(param, lr=self._lr)
[docs] @abc.abstractmethod def get_psf(self): """ Abstract method for getting the PSF of the mask. Should be fully compatible with pytorch autograd. Returns ------- :py:class:`~torch.Tensor` The PSF of the mask. """ raise NotImplementedError
[docs] def update_mask(self): """Update the mask parameters. According to externaly updated gradiants.""" self._optimizer.step() self._optimizer.zero_grad(set_to_none=True) self.project() self._counter += 1
[docs] @abc.abstractmethod def project(self): """Abstract method for projecting the mask parameters to a valid space (should be a subspace of [0,1]).""" raise NotImplementedError
[docs]class TrainablePSF(TrainableMask): # class TrainablePSF(torch.nn.Module, TrainableMask): """ Class for defining an object that directly optimizes the PSF, without any constraints on what can be realized physically. Parameters ---------- grayscale : bool, optional Whether mask should be returned as grayscale when calling :py:class:`~lensless.hardware.trainable_mask.TrainableMask.get_psf`. Otherwise PSF will be returned as RGB. By default False. """
[docs] def __init__(self, initial_psf, grayscale=False, **kwargs): super().__init__(**kwargs) self._psf = torch.nn.Parameter(initial_psf) initial_param = [self._psf] self._set_optimizer(initial_param) # checks assert len(initial_psf.shape) == 4, "Mask must be of shape (depth, height, width, channels)" self.grayscale = grayscale self._is_grayscale = is_grayscale(initial_psf) if grayscale: assert self._is_grayscale, "PSF must be grayscale"
[docs] def get_psf(self): if self._is_grayscale: if self.grayscale: # simulation in grayscale return self._psf else: # replicate to 3 channels return self._psf.expand(-1, -1, -1, 3) else: # assume RGB return self._psf
[docs] def project(self): self._psf.data = torch.clamp(self._psf, 0, 1)
class AdafruitLCD(TrainableMask): def __init__( self, initial_vals, sensor, slm, train_mask_vals=True, color_filter=None, rotate=None, flipud=False, use_waveprop=False, vertical_shift=None, horizontal_shift=None, scene2mask=None, mask2sensor=None, deadspace=True, downsample=None, min_val=0, **kwargs, ): """ Parameters ---------- initial_vals : :py:class:`~torch.Tensor` Initial mask parameters. sensor : :py:class:`~lensless.hardware.sensor.VirtualSensor` Sensor object. slm_param : :py:class:`~lensless.hardware.slm.SLMParam` SLM parameters. rotate : float, optional Rotation angle in degrees, by default None. flipud : bool, optional Whether to flip the mask vertically, by default False. use_waveprop : bool, optional Whether to use wave propagation for simulating PSF. If False, PSF will simply be intensity of mask pattern, by default False. vertical_shift : int, optional Vertical shift of the mask, by default None. horizontal_shift : int, optional Horizontal shift of the mask, by default None. scene2mask : :py:class:`~torch.Tensor`, optional Distance from scene to mask. Used for wave propagation, by default None. mask2sensor : :py:class:`~torch.Tensor`, optional Distance from mask to sensor. Used for wave propagation, by default None. downsample : int, optional Downsample factor, by default None. min_val : float, optional Minimum value for mask weights, by default 0. """ super().__init__(**kwargs) self.train_mask_vals = train_mask_vals if train_mask_vals: self._vals = torch.nn.Parameter(initial_vals) else: self._vals = initial_vals initial_param = None if color_filter is not None: self._color_filter = torch.nn.Parameter(color_filter) if train_mask_vals: initial_param = [self._vals, self._color_filter] else: initial_param = [self._color_filter] else: assert ( train_mask_vals ), "If color filter is not trainable, mask values must be trainable" initial_param = [self._vals] self._color_filter = None assert initial_param is not None, "Initial parameters must be set" # set optimizer self._set_optimizer(initial_param) self.slm_param = slm_dict[slm] self.device = slm self.sensor = VirtualSensor.from_name(sensor, downsample=downsample) self.rotate = rotate self.flipud = flipud self.use_waveprop = use_waveprop self.scene2mask = scene2mask self.mask2sensor = mask2sensor self.deadspace = deadspace self.vertical_shift = vertical_shift self.horizontal_shift = horizontal_shift self.min_val = min_val if downsample is not None and vertical_shift is not None: self.vertical_shift = vertical_shift // downsample if downsample is not None and horizontal_shift is not None: self.horizontal_shift = horizontal_shift // downsample if self.use_waveprop: assert self.scene2mask is not None assert self.mask2sensor is not None def get_psf(self): mask = get_programmable_mask( vals=self._vals, sensor=self.sensor, slm_param=self.slm_param, rotate=self.rotate, flipud=self.flipud, color_filter=self._color_filter, deadspace=self.deadspace, ) if self.vertical_shift is not None: mask = torch.roll(mask, self.vertical_shift, dims=1) if self.horizontal_shift is not None: mask = torch.roll(mask, self.horizontal_shift, dims=2) psf_in = get_intensity_psf( mask=mask, sensor=self.sensor, waveprop=self.use_waveprop, scene2mask=self.scene2mask, mask2sensor=self.mask2sensor, ) # add first dimension (depth) psf_in = psf_in.unsqueeze(0) # move channels to last dimension psf_in = psf_in.permute(0, 2, 3, 1) # flip mask psf_in = torch.flip(psf_in, dims=[-3, -2]) # normalize psf_in = psf_in / psf_in.norm() return psf_in def project(self): if self.train_mask_vals: self._vals.data = torch.clamp(self._vals, self.min_val, 1) if self._color_filter is not None: self._color_filter.data = torch.clamp(self._color_filter, 0, 1) # normalize each row to 1 self._color_filter.data = self._color_filter / self._color_filter.sum( dim=[1, 2] ).unsqueeze(-1).unsqueeze(-1) class TrainableCodedAperture(TrainableMask): def __init__( self, sensor_name, downsample=None, binary=True, torch_device="cuda", **kwargs, ): """ TODO: Distinguish between separable and non-separable. """ # 1) call base constructor so parameters can be set super().__init__(**kwargs) # 2) initialize mask assert "distance_sensor" in kwargs, "Distance to sensor must be specified" assert "method" in kwargs, "Method must be specified." assert "n_bits" in kwargs, "Number of bits must be specified." self._mask_obj = CodedAperture.from_sensor( sensor_name, downsample, is_torch=True, torch_device=torch_device, **kwargs, ) # 3) set learnable parameters (should be immediate attributes of the class) self._row = None self._col = None self._mask = None if self._mask_obj.row is not None: # separable self.separable = True self._row = torch.nn.Parameter(self._mask_obj.row) self._col = torch.nn.Parameter(self._mask_obj.col) initial_param = [self._row, self._col] else: # non-separable self.separable = False self._mask = torch.nn.Parameter(self._mask_obj.mask) initial_param = [self._mask] self.binary = binary # 4) set optimizer self._set_optimizer(initial_param) # 5) compute PSF self._psf = None self.project() def get_psf(self): return self._psf def project(self): with torch.no_grad(): if self.separable: self._row.data = torch.clamp(self._row, 0, 1) self._col.data = torch.clamp(self._col, 0, 1) if self.binary: self._row.data = torch.round(self._row) self._col.data = torch.round(self._col) else: self._mask.data = torch.clamp(self._mask, 0, 1) if self.binary: self._mask.data = torch.round(self._mask) # recompute PSF self._mask_obj.create_mask(self._row, self._col, mask=self._mask) self._mask_obj.compute_psf() self._psf = self._mask_obj.psf.unsqueeze(0) self._psf = self._psf / self._psf.norm() """ Utility functions to help prepare trainable masks. """ mask_type_to_class = { "AdafruitLCD": AdafruitLCD, "TrainablePSF": TrainablePSF, "TrainableCodedAperture": TrainableCodedAperture, "TrainableHeightVarying": None, "TrainableMultiLensArray": None, } def prep_trainable_mask(config, psf=None, downsample=None): mask = None color_filter = None downsample = config["files"]["downsample"] if downsample is None else downsample mask_type = config["trainable_mask"]["mask_type"] if mask_type is not None: assert mask_type in mask_type_to_class.keys(), ( f"Trainable mask type {mask_type} not supported. " f"Supported types are {mask_type_to_class.keys()}" ) mask_class = mask_type_to_class[mask_type] # -- trainable mask object if isinstance(config["trainable_mask"]["initial_value"], omegaconf.dictconfig.DictConfig): # from mask config mask = mask_class( # mask = TrainableCodedAperture( sensor_name=config.simulation.sensor, downsample=downsample, distance_sensor=config.simulation.mask2sensor, optimizer=config.trainable_mask.optimizer, lr=config.trainable_mask.mask_lr, binary=config.trainable_mask.binary, torch_device=config.torch_device, **config.trainable_mask.initial_value, ) else: if config["trainable_mask"]["initial_value"] == "random": if psf is not None: initial_mask = torch.rand_like(psf) else: sensor = VirtualSensor.from_name( config["simulation"]["sensor"], downsample=downsample ) resolution = sensor.resolution initial_mask = torch.rand((1, *resolution, 3)) elif config["trainable_mask"]["initial_value"] == "psf": if config["reconstruction"]["method"] == "svdeconvnet": # learnable PSF as part of SVDeconvNet n_psf = config["reconstruction"]["svdeconvnet"]["K"] ** 2 initial_mask = psf.repeat(n_psf, 1, 1, 1) else: initial_mask = psf.clone() # if file ending with "npy" elif config["trainable_mask"]["initial_value"].endswith("npy"): pattern = np.load(config["trainable_mask"]["initial_value"]) initial_mask = full2subpattern( pattern=pattern, shape=config["trainable_mask"]["ap_shape"], center=config["trainable_mask"]["ap_center"], slm=config["trainable_mask"]["slm"], ) initial_mask = torch.from_numpy(initial_mask.astype(np.float32)) # prepare color filter if needed from waveprop.devices import slm_dict from waveprop.devices import SLMParam as SLMParam_wp slm_param = slm_dict[config["trainable_mask"]["slm"]] if ( config["trainable_mask"]["train_color_filter"] and SLMParam_wp.COLOR_FILTER in slm_param.keys() ): color_filter = slm_param[SLMParam_wp.COLOR_FILTER] color_filter = torch.from_numpy(color_filter.copy()).to(dtype=torch.float32) # TODO: add small random values? color_filter = color_filter + 0.1 * torch.rand_like(color_filter) else: raise ValueError( f"Initial PSF value {config['trainable_mask']['initial_value']} not supported" ) # convert to grayscale if needed if config["trainable_mask"]["grayscale"] and not is_grayscale(initial_mask): initial_mask = rgb2gray(initial_mask) mask = mask_class( initial_mask, downsample=downsample, color_filter=color_filter, **config["trainable_mask"], ) return mask