Source code for lensless.eval.benchmark

# #############################################################################
# benchmark.py
# =================
# Authors :
# Yohann PERRON
# Eric BEZZAM [ebezzam@gmail.com]
# #############################################################################


from lensless.utils.dataset import DiffuserCamTestDataset
from lensless.utils.io import save_image
from waveprop.noise import add_shot_noise
from tqdm import tqdm
import os
import numpy as np
import wandb

try:
    import torch
    from torch.utils.data import DataLoader
    from torch.nn import MSELoss, L1Loss
    from torchmetrics import StructuralSimilarityIndexMeasure
    from torchmetrics.image import lpip, psnr
    from torch.optim import SGD
except ImportError:
    raise ImportError(
        "Torch, torchvision, and torchmetrics are needed to benchmark reconstruction algorithm."
    )


# TODO put Parameterize and Perturb functions in a separate file
def get_param(model):
    """
    Gather every parameter of ``model`` into a single flat tensor.

    Parameters are visited in ``model.named_parameters()`` order, flattened and
    concatenated. No ``detach`` is applied to the result.

    Parameters
    ----------
    model : object exposing ``named_parameters()`` (e.g. a :py:class:`torch.nn.Module`)
        Model whose parameters are gathered.

    Returns
    -------
    :py:class:`torch.Tensor`
        1D tensor holding all parameter values.
    """
    return torch.cat([weights.flatten() for _name, weights in model.named_parameters()])


def pnp_loss(y_est, y, recon, mu, original_param):
    """
    Loss minimized by parameterize-and-perturb (PnP) fine-tuning.

    Sum of a data-fidelity term (mean squared error between the simulated and the
    measured lensless image) and ``mu`` times the mean squared distance between the
    current parameters of ``recon`` and ``original_param``.

    Parameters
    ----------
    y_est : :py:class:`torch.Tensor`
        Simulated measurement (assumed broadcast-compatible with ``y``).
    y : :py:class:`torch.Tensor`
        Measured lensless image.
    recon : model
        Model being fine-tuned; its parameters are gathered with :func:`get_param`.
    mu : float
        Weight of the parameter-distance regularizer.
    original_param : :py:class:`torch.Tensor`
        Flat vector of reference parameters (same layout as ``get_param(recon)``).

    Returns
    -------
    :py:class:`torch.Tensor`
        Scalar loss.
    """
    new_param = get_param(recon)
    data_fidelity = torch.mean((y_est - y) ** 2)
    # The reference parameters are a fixed anchor: detach them so backpropagation does
    # not push (and accumulate) gradients into the model they were read from.
    param_distance = torch.mean((new_param - original_param.detach()) ** 2)
    return data_fidelity + mu * param_distance


def swap_color_channels(tensor, channels):
    """
    Swap or re-order the colour channels (last axis) of an image, in place.

    Works for both :py:class:`numpy.ndarray` and :py:class:`torch.Tensor`.
    Advanced indexing on the right-hand side returns a copy, so no explicit
    ``.copy()`` (numpy-only) is needed.

    Parameters
    ----------
    tensor : :py:class:`numpy.ndarray` or :py:class:`torch.Tensor`
        Image whose last axis holds the colour channels. It is modified in place.
    channels : sequence of int
        - two indices: those two channels are swapped;
        - three indices: output channel ``k`` (k = 0, 1, 2) is taken from input
          channel ``channels[k]``;
        - any other length: the image is left unchanged.

    Returns
    -------
    Same object as ``tensor``.
    """
    channels = list(channels)
    if len(channels) == 2:
        tensor[..., channels] = tensor[..., channels[::-1]]
    elif len(channels) == 3:
        # reorder channels
        tensor[..., [0, 1, 2]] = tensor[..., channels]
    return tensor


def _default_metrics(device):
    """
    Build the default metrics dictionary on ``device``.

    ``"ReconstructionError"`` maps to ``None``: it is not a torchmetrics object but is
    computed by :func:`benchmark` through ``model.reconstruction_error``.
    """
    return {
        "MSE": MSELoss(reduction="mean").to(device),
        # reduction="sum": per-batch LPIPS values are summed and divided by the dataset size
        "LPIPS_Vgg": lpip.LearnedPerceptualImagePatchSimilarity(
            net_type="vgg", normalize=True, reduction="sum"
        ).to(device),
        "PSNR": psnr.PeakSignalNoiseRatio(reduction=None, dim=(1, 2, 3), data_range=(0, 1)).to(
            device
        ),
        "SSIM": StructuralSimilarityIndexMeasure(reduction=None, data_range=(0, 1)).to(device),
        "ReconstructionError": None,
    }


def _check_pnp_config(pnp, batchsize):
    """Validate the parameterize-and-perturb (PnP) options given to :func:`benchmark`."""
    assert "mu" in pnp, "mu must be provided"
    assert "lr" in pnp, "lr must be provided"
    assert "n_iter" in pnp, "n_iter must be provided"
    assert "model_path" in pnp, "model_path must be provided"
    assert isinstance(pnp["mu"], float), "mu must be a float"
    assert isinstance(pnp["lr"], float), "lr must be a float"
    assert isinstance(pnp["n_iter"], int), "n_iter must be an int"
    assert isinstance(pnp["model_path"], str), "model_path must be a string"
    assert batchsize == 1, "batchsize must be 1 for parameterize and perturb"


def _pnp_reconstruct(lensless, psf, pnp, original_param, device):
    """
    Reconstruct one measurement with parameterize-and-perturb.

    A fresh model is loaded from ``pnp["model_path"]``. Its parameters are then updated
    with SGD for ``pnp["n_iter"]`` steps to minimize :func:`pnp_loss`. The returned
    (detached) reconstruction is the one computed in the last iteration, i.e. before
    that iteration's optimizer step.
    """
    from lensless.recon.model_dict import load_model
    from lensless.recon.rfft_convolve import RealFFTConvolve2D

    psf = psf.to(device)
    recon_pnp = load_model(pnp["model_path"], psf, device=device, verbose=False)

    # define optimizer
    optimizer = SGD(recon_pnp.parameters(), lr=pnp["lr"])

    # create forward model (which needs padding)
    forward_model_param = recon_pnp._convolver_param.copy()
    forward_model_param["pad"] = True
    forward_model = RealFFTConvolve2D(psf, **forward_model_param)

    # iterate / minimize loss
    for _ in range(pnp["n_iter"]):
        optimizer.zero_grad()
        recon_pnp.set_data(lensless)
        prediction = recon_pnp.apply(disp_iter=-1, save=False, gamma=None, plot=False)

        # simulate measurement, then min-max normalize it
        y_est = forward_model.convolve(prediction)
        y_est = y_est - y_est.min()
        y_est = y_est / y_est.max()

        loss = pnp_loss(y_est, lensless, recon_pnp, pnp["mu"], original_param)
        loss.backward()
        optimizer.step()

    return prediction.detach()


def _crop(x, crop):
    """Crop the last two axes of ``x`` to ``crop["vertical"]`` x ``crop["horizontal"]``."""
    return x[
        ...,
        crop["vertical"][0] : crop["vertical"][1],
        crop["horizontal"][0] : crop["horizontal"][1],
    ]


def _evaluate_metric(name, metric_fn, prediction, lensed, n_batch):
    """
    Evaluate one metric on [N, C, H, W] tensors and return the list of values to store.

    MSE is scaled by ``n_batch`` and LPIPS returns one value per batch, so that the
    final average can divide their sums by the dataset size.
    """
    if "LPIPS" in name:
        if prediction.shape[1] == 1:
            # LPIPS needs 3 channels
            prediction = prediction.repeat(1, 3, 1, 1)
            lensed = lensed.repeat(1, 3, 1, 1)
        return [metric_fn(prediction, lensed).cpu().item()]
    if name == "MSE":
        return [metric_fn(prediction, lensed).cpu().item() * n_batch]
    vals = metric_fn(prediction, lensed).cpu()
    if hasattr(vals.tolist(), "__len__"):
        return vals.tolist()
    return [vals.item()]


def _save_tensor_image(tensor, fp, swap_channels):
    """Save ``tensor`` (squeezed) as an image at ``fp``, re-ordering colour channels if requested."""
    # copy: for CPU tensors ``.numpy()`` shares memory and swap_color_channels works in
    # place, which would otherwise corrupt tensors that are still used for metrics
    img = tensor.cpu().numpy().squeeze().copy()
    if swap_channels:
        img = swap_color_channels(img, swap_channels)
    save_image(img, fp=fp)


def benchmark(
    model,
    dataset,
    batchsize=1,
    metrics=None,
    crop=None,
    save_idx=None,
    save_intermediate=False,
    output_dir=None,
    unrolled_output_factor=False,
    pre_process_aux=False,
    return_average=True,
    snr=None,
    use_wandb=False,
    label=None,
    epoch=None,
    use_background=True,
    pnp=None,
    swap_channels=False,
    **kwargs,
):
    """
    Compute multiple metrics for a reconstruction algorithm.

    Parameters
    ----------
    model : :py:class:`~lensless.ReconstructionAlgorithm`
        Reconstruction algorithm to benchmark. Its PSF must be a ``torch.Tensor``.
    dataset : :py:class:`~lensless.benchmark.ParallelDataset`
        Parallel dataset of lensless and lensed images.
    batchsize : int, optional
        Batch size for processing. For maximum compatibility use 1 (a batch size above 1
        is not supported by all algorithms), by default 1.
    metrics : dict, optional
        Dictionary of metrics to compute. If None, MSE, LPIPS (VGG), PSNR, SSIM and the
        reconstruction error are computed. A ``"ReconstructionError"`` key (value ignored)
        uses ``model.reconstruction_error``.
    crop : dict, optional
        Dictionary of crop parameters (vertical: [start, end], horizontal: [start, end]),
        by default None (no crop). Only used if the dataset has no ``alignment`` attribute.
    save_idx : list of int, optional
        List of indices to save the predictions, by default None (not to save any).
    save_intermediate : bool, optional
        If True, also save the intermediate outputs of the indices in ``save_idx``
        (``<idx>_inv.png``, ``<idx>_preproc.png`` and, if available, ``<idx>_psfs.png``),
        by default False.
    output_dir : str, optional
        Directory to save the predictions, by default the working directory. Created
        (with parents) if it does not exist.
    unrolled_output_factor : bool, optional
        If True, also compute the metrics (except the reconstruction error) on the
        unrolled output, stored under ``<metric>_unrolled``, by default False.
    pre_process_aux : bool, optional
        If True, also compute ``ReconstructionError_PreProc``: the reconstruction error of
        the output with respect to the pre-processor output, by default False.
    return_average : bool, optional
        If True, return the average value of the metrics, by default True.
    snr : float, optional
        Signal to noise ratio for adding shot noise. If None, no noise is added, by
        default None.
    use_wandb : bool, optional
        If True, log each saved reconstruction to Weights & Biases under the key
        ``<idx>_<label>`` (or ``<idx>`` if ``label`` is None), by default False.
    label : str, optional
        Suffix for the wandb log key, by default None.
    epoch : int, optional
        wandb logging step; required when ``use_wandb`` is True.
    use_background : bool, optional
        If dataset has background, use it for reconstruction, by default True.
    pnp : dict, optional
        Dictionary of parameters for (Parameterize and perturb) algorithm, by default None.
        Required keys: "mu" (distance from original parameters), "lr" (SGD learning rate),
        "n_iter" (number of iterations), "model_path" (original model path).
    swap_channels : list of int or False, optional
        Channel indices passed to :func:`swap_color_channels` when saving intermediate
        images (two indices swap, three re-order). False (default) disables swapping.
    **kwargs
        Forwarded to ``model.apply`` (batch size 1) or ``model.forward``.

    Returns
    -------
    dict
        Mapping from metric name to its average value if ``return_average`` is True,
        otherwise to the list of collected (per-batch or per-sample) values.
    """
    assert isinstance(model._psf, torch.Tensor), "model need to be constructed with torch support"
    device = model._psf.device

    if output_dir is None:
        output_dir = os.getcwd()
    else:
        output_dir = str(output_dir)
        os.makedirs(output_dir, exist_ok=True)

    if pnp is not None:
        _check_pnp_config(pnp, batchsize)
        # fixed anchor for the PnP regularizer: detached so no gradient reaches ``model``
        original_param = get_param(model).detach()

    if metrics is None:
        metrics = _default_metrics(device)

    metrics_values = {key: [] for key in metrics}
    if unrolled_output_factor:
        for key in metrics:
            if key != "ReconstructionError":
                metrics_values[key + "_unrolled"] = []
    if pre_process_aux:
        metrics_values["ReconstructionError_PreProc"] = []

    output_intermediate = unrolled_output_factor or pre_process_aux or save_intermediate

    # loop over batches
    # ``device`` is a torch.device: compare its type (``device != "cpu"`` is always True)
    dataloader = DataLoader(dataset, batch_size=batchsize, pin_memory=(device.type != "cpu"))
    model.reset()
    idx = 0
    for batch in tqdm(dataloader):
        n_batch = len(batch[0])
        flip_lr = None
        flip_ud = None
        background = None
        lensless = batch[0].to(device)
        lensed = batch[1].to(device)
        if dataset.measured_bg and use_background:
            background = batch[-1].to(device)
        if dataset.multimask or dataset.random_flip:
            psfs = batch[2].to(device)
        else:
            psfs = None
        if dataset.random_flip:
            flip_lr = batch[3]
            flip_ud = batch[4]

        # add shot noise
        if snr is not None:
            for i in range(lensless.shape[0]):
                lensless[i] = add_shot_noise(lensless[i], float(snr))

        # compute predictions
        if batchsize == 1:
            if pnp is not None:
                # the dataset only provides per-sample PSFs for multimask / random_flip;
                # otherwise fall back to the model's own PSF (previously crashed on None)
                psf = psfs[0] if psfs is not None else model._psf
                prediction = _pnp_reconstruct(lensless, psf, pnp, original_param, device)
            else:
                with torch.no_grad():
                    if psfs is not None:
                        model._set_psf(psfs[0])
                    model.set_data(lensless)
                    prediction = model.apply(
                        plot=False,
                        save=False,
                        output_intermediate=output_intermediate,
                        background=background,
                        **kwargs,
                    )
        else:
            with torch.no_grad():
                prediction = model.forward(
                    batch=lensless, psfs=psfs, background=background, **kwargs
                )

        if output_intermediate:
            psfs_out = prediction[3]
            pre_process_out = prediction[2]
            unrolled_out = prediction[1]
            prediction = prediction[0]
        prediction_original = prediction.clone()

        # Convert to [N*D, C, H, W] for torchmetrics
        prediction = prediction.reshape(-1, *prediction.shape[-3:]).movedim(-1, -3)
        lensed = lensed.reshape(-1, *lensed.shape[-3:]).movedim(-1, -3)

        # extract region of interest
        if hasattr(dataset, "alignment"):
            if dataset.alignment is not None:
                prediction = dataset.extract_roi(
                    prediction, axis=(-2, -1), flip_lr=flip_lr, flip_ud=flip_ud
                )
            else:
                prediction, lensed = dataset.extract_roi(
                    prediction, axis=(-2, -1), lensed=lensed, flip_lr=flip_lr, flip_ud=flip_ud
                )
        elif crop is not None:
            assert flip_lr is None and flip_ud is None
            prediction = _crop(prediction, crop)
            lensed = _crop(lensed, crop)

        # save selected reconstructions (and optionally the intermediate outputs)
        if save_idx is not None:
            for i, _batch_idx in enumerate(range(idx, idx + n_batch)):
                if _batch_idx not in save_idx:
                    continue
                # switch to [H, W, C] for saving
                prediction_np = np.moveaxis(prediction[i].cpu().numpy(), 0, -1)
                prediction_fp = os.path.join(output_dir, f"{_batch_idx}.png")
                save_image(prediction_np, fp=prediction_fp)

                if save_intermediate:
                    _save_tensor_image(
                        unrolled_out[i],
                        os.path.join(output_dir, f"{_batch_idx}_inv.png"),
                        swap_channels,
                    )
                    _save_tensor_image(
                        pre_process_out[i],
                        os.path.join(output_dir, f"{_batch_idx}_preproc.png"),
                        swap_channels,
                    )
                    if psfs_out is not None:
                        # leading dimension of 1: the whole tensor is saved for every sample
                        psf_i = psfs_out if psfs_out.shape[0] == 1 else psfs_out[i]
                        _save_tensor_image(
                            psf_i,
                            os.path.join(output_dir, f"{_batch_idx}_psfs.png"),
                            swap_channels,
                        )

                if use_wandb:
                    assert epoch is not None, "epoch must be provided for wandb logging"
                    log_key = f"{_batch_idx}_{label}" if label is not None else f"{_batch_idx}"
                    # log the reconstruction itself, not the last intermediate file written
                    wandb.log({log_key: wandb.Image(prediction_fp)}, step=epoch)

        # normalization
        prediction_max = torch.amax(prediction, dim=(-1, -2, -3), keepdim=True)
        if torch.all(prediction_max != 0):
            prediction = prediction / prediction_max
        else:
            print("Warning: prediction is zero")
        lensed_max = torch.amax(lensed, dim=(1, 2, 3), keepdim=True)
        lensed = lensed / lensed_max

        # compute metrics
        for metric in metrics:
            if metric == "ReconstructionError":
                metrics_values[metric] += model.reconstruction_error(
                    prediction=prediction_original, lensless=lensless, psfs=psfs
                ).tolist()
                continue
            try:
                metrics_values[metric] += _evaluate_metric(
                    metric, metrics[metric], prediction, lensed, n_batch
                )
            except Exception as e:
                # best effort: a failing metric is reported but does not stop the benchmark
                print(f"Error in metric {metric}: {e}")

        # compute metrics for unrolled output
        if unrolled_output_factor:
            # -- convert to CHW and remove depth
            unrolled_out = unrolled_out.reshape(-1, *unrolled_out.shape[-3:]).movedim(-1, -3)

            # -- extract region of interest (``lensed`` was already extracted above)
            if hasattr(dataset, "alignment"):
                # NOTE(review): unlike ``prediction``, no flip_lr / flip_ud is passed here
                unrolled_out = dataset.extract_roi(unrolled_out, axis=(-2, -1))
                assert lensed.shape == unrolled_out.shape
            elif crop is not None:
                unrolled_out = _crop(unrolled_out, crop)

            # -- normalization
            unrolled_out_max = torch.amax(unrolled_out, dim=(-1, -2, -3), keepdim=True)
            if torch.all(unrolled_out_max != 0):
                unrolled_out = unrolled_out / unrolled_out_max

            # -- compute metrics (reconstruction error only exists for the final output)
            for metric in metrics:
                if metric == "ReconstructionError":
                    continue
                metrics_values[metric + "_unrolled"] += _evaluate_metric(
                    metric, metrics[metric], unrolled_out, lensed, n_batch
                )

        # compute metrics for pre-processed output
        if pre_process_aux:
            metrics_values["ReconstructionError_PreProc"] += model.reconstruction_error(
                prediction=prediction_original, lensless=pre_process_out
            ).tolist()

        model.reset()
        idx += n_batch

    # average metrics
    if return_average:
        for metric in metrics_values:
            if "MSE" in metric or "LPIPS" in metric:
                # values are per batch (already summed over the batch), so divide by dataset size
                metrics_values[metric] = np.sum(metrics_values[metric]) / len(dataset)
            else:
                metrics_values[metric] = np.mean(metrics_values[metric])

    return metrics_values
if __name__ == "__main__":
    from lensless import ADMM

    # run on GPU when one is available
    device = "cuda" if torch.cuda.is_available() else "cpu"

    # 10 DiffuserCam test files at full resolution, ADMM with 100 iterations
    dataset = DiffuserCamTestDataset(n_files=10, downsample=1.0)
    model = ADMM(dataset.psf.to(device), n_iter=100)

    print(benchmark(model, dataset, batchsize=1))