# Source code for lensless.utils.io

# #############################################################################
# io.py
# =================
# Authors :
# Eric BEZZAM [ebezzam@gmail.com]
# #############################################################################


import os.path
import warnings

import cv2
import numpy as np
from PIL import Image

from lensless.hardware.constants import RPI_HQ_CAMERA_BLACK_LEVEL, RPI_HQ_CAMERA_CCM_MATRIX
from lensless.utils.image import bayer2rgb_cc, print_image_info, resize, rgb2gray, get_max_val
from lensless.utils.plot import plot_image


def load_image(
    fp,
    verbose=False,
    flip=False,
    flip_ud=False,
    flip_lr=False,
    bayer=False,
    black_level=RPI_HQ_CAMERA_BLACK_LEVEL,
    blue_gain=None,
    red_gain=None,
    ccm=RPI_HQ_CAMERA_CCM_MATRIX,
    back=None,
    nbits_out=None,
    as_4d=False,
    downsample=None,
    bg=None,
    return_float=False,
    shape=None,
    dtype=None,
    normalize=True,
    bgr_input=True,
):
    """
    Load image as numpy array.

    Parameters
    ----------
    fp : str
        Full path to file.
    verbose : bool, optional
        Whether to print info about the loaded image.
    flip : bool
        Whether to flip data (vertical and horizontal).
    flip_ud : bool
        Whether to flip data vertically (applied after ``flip``).
    flip_lr : bool
        Whether to flip data horizontally (applied after ``flip``).
    bayer : bool
        Whether input data is Bayer.
    black_level : float
        Black level. Default is to use that of Raspberry Pi HQ camera.
        Overwritten by the file metadata for DNG input.
    blue_gain : float
        Blue gain for color correction.
    red_gain : float
        Red gain for color correction.
    ccm : :py:class:`~numpy.ndarray`
        Color correction matrix. Default is to use that of Raspberry Pi HQ
        camera. Overwritten by the file metadata for DNG input.
    back : str, optional
        Path to a background image that is subtracted from the raw Bayer data
        before demosaicing. Only used if ``bayer`` is set.
    nbits_out : int
        Output bit depth of the Bayer conversion. Default is to use that of
        input.
    as_4d : bool
        Add depth and color dimensions if necessary so that image is 4D:
        (depth, height, width, color).
    downsample : int, optional
        Downsampling factor. Recommended for image reconstruction.
    bg : array_like
        Background level to subtract. If its maximum is at most 1 and the
        image is not a float array, it is first scaled by the maximum value
        of the image data type.
    return_float : bool
        Whether to return image as float array, or unsigned int.
    shape : tuple, optional
        Shape (H, W, C) to resize to.
    dtype : str, optional
        Data type of returned data. Default is to use that of input, or
        float32 if ``return_float``.
    normalize : bool, default True
        If ``return_float``, whether to normalize data to maximum value of 1.
    bgr_input : bool, default True
        Whether a 3-channel, non-Bayer input is in BGR order and has to be
        converted to RGB.

    Returns
    -------
    img : :py:class:`~numpy.ndarray`
        Loaded image. 4D (depth, height, width, color) if ``as_4d`` is set.

    Raises
    ------
    ValueError
        If the file could not be decoded as an image.
    """
    assert os.path.isfile(fp)

    # Pick the loader from the file extension rather than from a substring of
    # the whole path, so that e.g. a parent directory called "dng_captures"
    # does not send a PNG through the RAW loader.
    ext = os.path.splitext(fp)[1].lower()

    nbits = None  # input bit depth
    if ext == ".dng":
        import rawpy

        assert bayer
        raw = rawpy.imread(fp)
        img = raw.raw_image
        # TODO : use raw.postprocess? too much unknown processing...
        # TODO : fall back on raw.camera_whitebalance if gains are not given?
        nbits = int(np.ceil(np.log2(raw.white_level)))
        ccm = raw.color_matrix[:, :3]
        black_level = np.array(raw.black_level_per_channel[:3]).astype(np.float32)
    elif ext in (".npy", ".npz"):
        img = np.load(fp)
    else:
        img = cv2.imread(fp, cv2.IMREAD_UNCHANGED)
        if img is None:
            # cv2.imread signals failure by returning None instead of raising
            raise ValueError(f"Could not read image file: {fp}")

    if bayer:
        assert len(img.shape) == 2, img.shape
        if nbits is None:
            # more than 8 bits of range -> assume 12-bit HQ camera data
            nbits = 12 if img.max() > 255 else 8

        if back:
            back_img = cv2.imread(back, cv2.IMREAD_UNCHANGED)
            # keep the raw dtype in its own variable: the `dtype` argument
            # is the requested *output* type and must not be overwritten here
            raw_dtype = img.dtype
            img = img.astype(np.float32) - back_img.astype(np.float32)
            img = np.clip(img, a_min=0, a_max=img.max())
            img = img.astype(raw_dtype)

        if nbits_out is None:
            nbits_out = nbits
        img = bayer2rgb_cc(
            img,
            nbits=nbits,
            blue_gain=blue_gain,
            red_gain=red_gain,
            black_level=black_level,
            ccm=ccm,
            nbits_out=nbits_out,
        )
    elif len(img.shape) == 3 and bgr_input:
        img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB)

    original_dtype = img.dtype

    if flip:
        img = np.flipud(img)
        img = np.fliplr(img)
    if flip_ud:
        img = np.flipud(img)
    if flip_lr:
        img = np.fliplr(img)

    if bg is not None:
        # if bg is float vector, turn into int-valued vector
        if bg.max() <= 1 and img.dtype not in [np.float32, np.float64]:
            bg = bg * get_max_val(img)

        img = img - bg
        img = np.clip(img, a_min=0, a_max=img.max())

    if as_4d:
        if len(img.shape) == 3:
            img = img[np.newaxis, :, :, :]
        elif len(img.shape) == 2:
            img = img[np.newaxis, :, :, np.newaxis]

    if downsample is not None or shape is not None:
        factor = None if downsample is None else 1 / downsample
        img = resize(img, factor=factor, shape=shape)

    if return_float:
        if dtype is None:
            dtype = np.float32
        assert dtype == np.float32 or dtype == np.float64
        img = img.astype(dtype)
        if normalize:
            img /= img.max()
    else:
        if dtype is None:
            dtype = original_dtype
        img = img.astype(dtype)

    if verbose:
        print_image_info(img)

    return img
def load_psf(
    fp,
    downsample=1,
    return_float=True,
    bg_pix=(5, 25),
    return_bg=False,
    flip=False,
    flip_ud=False,
    flip_lr=False,
    verbose=False,
    bayer=False,
    blue_gain=None,
    red_gain=None,
    dtype=np.float32,
    nbits_out=None,
    single_psf=False,
    shape=None,
    use_3d=False,
    bgr_input=True,
    force_rgb=False,
):
    """
    Load and process PSF for analysis or for reconstruction.

    Basic steps are:

    * Load image.
    * (Optionally) subtract background. Recommended.
    * (Optionally) resize to more manageable size
    * (Optionally) normalize if using for reconstruction; otherwise cast back
      to the input data type for analysis.

    Parameters
    ----------
    fp : str
        Full path to file.
    downsample : int, optional
        Downsampling factor. Recommended for image reconstruction. ``None``
        is accepted and means that only ``shape`` (if given) drives the
        resizing.
    return_float : bool, optional
        Whether to return PSF as float array, or in the input data type.
    bg_pix : tuple, optional
        Section of pixels to take from top left corner to remove background
        level. Set to `None` to omit this step, although it is highly
        recommended.
    return_bg : bool, optional
        Whether to return background level, for removing from data for
        reconstruction.
    flip : bool, optional
        Whether to flip up-down and left-right.
    flip_ud : bool, optional
        Whether to flip up-down.
    flip_lr : bool, optional
        Whether to flip left-right.
    verbose : bool
        Whether to print metadata.
    bayer : bool
        Whether input data is Bayer.
    blue_gain : float
        Blue gain for color correction.
    red_gain : float
        Red gain for color correction.
    dtype : float32 or float64
        Data type of returned data.
    nbits_out : int
        Output bit depth. Default is to use that of input.
    single_psf : bool
        Whether to sum RGB channels into single PSF, same across channels.
        Done in "Learned reconstructions for practical mask-based lensless
        imaging" of Kristina Monakhova et. al.
    shape : tuple, optional
        Shape to resize to.
    use_3d : bool, optional
        Whether to load the PSF directly from a ``.npy`` / ``.npz`` file
        instead of going through :py:func:`load_image`.
    bgr_input : bool, optional
        Forwarded to :py:func:`load_image`.
    force_rgb : bool, optional
        Whether to stack a 2D PSF three times along a new last axis.

    Returns
    -------
    psf : :py:class:`~numpy.ndarray`
        4-D array of PSF.
    bg : :py:class:`~numpy.ndarray`
        Background level, only returned if ``return_bg`` is set.

    Raises
    ------
    ValueError
        If ``use_3d`` is set and the file is neither ``.npy`` nor a non-empty
        ``.npz`` archive.
    """
    # load image data and extract necessary channels
    if use_3d:
        assert os.path.isfile(fp)
        if fp.endswith(".npy"):
            psf = np.load(fp)
        elif fp.endswith(".npz"):
            # open the archive once and close it as soon as the array is read
            with np.load(fp) as archive:
                if len(archive.files) > 1:
                    print("Warning: more than one array in .npz archive, using first")
                elif len(archive.files) == 0:
                    raise ValueError("No arrays in .npz archive")
                psf = archive[archive.files[0]]
        else:
            raise ValueError("File format not supported")
    else:
        psf = load_image(
            fp,
            verbose=False,
            flip=flip,
            flip_ud=flip_ud,
            flip_lr=flip_lr,
            bayer=bayer,
            blue_gain=blue_gain,
            red_gain=red_gain,
            nbits_out=nbits_out,
            bgr_input=bgr_input,
        )

    original_dtype = psf.dtype
    max_val = get_max_val(psf)
    psf = np.array(psf, dtype=dtype)

    if force_rgb and len(psf.shape) == 2:
        psf = np.stack([psf] * 3, axis=2)

    # bring the PSF to 4D, adding the missing axis at the end for 3D files
    # and at the front (plus the end for 2D data) for regular images
    if use_3d:
        if len(psf.shape) == 3:
            grayscale = True
            psf = psf[:, :, :, np.newaxis]
        else:
            assert len(psf.shape) == 4
            grayscale = False
    else:
        if len(psf.shape) == 3:
            grayscale = False
            psf = psf[np.newaxis, :, :, :]
        else:
            assert len(psf.shape) == 2
            grayscale = True
            psf = psf[np.newaxis, :, :, np.newaxis]

    # check that all depths of the psf have the same shape.
    assert all(psf[0].shape == depth_slice.shape for depth_slice in psf)

    # subtract background, assume black edges
    if bg_pix is None:
        bg = np.zeros(len(np.shape(psf)))
    else:
        lo, hi = bg_pix[0], bg_pix[1]
        if grayscale:
            bg = np.mean(psf[:, lo:hi, lo:hi, :])
            psf -= bg
        else:
            # one background level per color channel
            bg = []
            for i in range(psf.shape[3]):
                bg_i = np.mean(psf[:, lo:hi, lo:hi, i])
                psf[:, :, :, i] -= bg_i
                bg.append(bg_i)

        psf = np.clip(psf, a_min=0, a_max=psf.max())
        bg = np.array(bg)

    # resize
    # `downsample` may be None when only a target shape is requested; in that
    # case no factor can be derived and the shape alone drives the resizing.
    if downsample is None:
        factor = None
        needs_resize = shape is not None
    else:
        factor = 1 / downsample
        needs_resize = downsample != 1 or shape is not None
    if needs_resize:
        psf = resize(psf, shape=shape, factor=factor)

    if single_psf:
        if not grayscale:
            # TODO : in Lensless Learning, they sum channels --> `psf_diffuser = np.sum(psf_diffuser,2)`
            # https://github.com/Waller-Lab/LenslessLearning/blob/master/pre-trained%20reconstructions.ipynb
            psf = np.sum(psf, axis=3)
            psf = psf[:, :, :, np.newaxis]
        else:
            warnings.warn("Notice : single_psf has no effect for grayscale psf")

    # normalize
    if return_float:
        # unit Euclidean norm rather than unit maximum
        psf /= np.linalg.norm(psf.ravel())
        bg /= max_val
    else:
        psf = psf.astype(original_dtype)

    if verbose:
        print_image_info(psf)

    if return_bg:
        return psf, bg
    else:
        return psf
def load_data(
    psf_fp,
    data_fp,
    background_fp=None,
    return_bg=False,
    remove_background=True,
    return_float=True,
    downsample=None,
    bg_pix=(5, 25),
    plot=True,
    flip=False,
    flip_ud=False,
    flip_lr=False,
    bayer=False,
    blue_gain=None,
    red_gain=None,
    gamma=None,
    gray=False,
    dtype=None,
    single_psf=False,
    shape=None,
    use_torch=False,
    torch_device="cpu",
    normalize=False,
    bgr_input=True,
):
    """
    Load data for image reconstruction.

    Parameters
    ----------
    psf_fp : str
        Full path to PSF file.
    data_fp : str
        Full path to measurement file.
    background_fp : str, optional
        Full path to a background measurement. If given, it replaces the
        background level estimated from the PSF.
    return_bg : bool, optional
        Whether to also return the background.
    remove_background : bool, optional
        Whether to subtract the background loaded from ``background_fp`` from
        the measurement (clipping at 0).
    return_float : bool, optional
        Whether to return PSF as float array, or unsigned int.
    downsample : int or float
        Downsampling factor. Required if ``shape`` is not given.
    bg_pix : tuple, optional
        Section of pixels to take from top left corner to remove background
        level. Set to `None` to omit this step, although it is highly
        recommended.
    plot : bool, optional
        Whether or not to plot PSF and raw data.
    flip : bool
        Whether to flip data (vertical and horizontal).
    flip_ud : bool
        Whether to flip data vertically.
    flip_lr : bool
        Whether to flip data horizontally.
    bayer : bool
        Whether input data is Bayer.
    blue_gain : float
        Blue gain for color correction.
    red_gain : float
        Red gain for color correction.
    gamma : float, optional
        Optional gamma factor to apply, ONLY for plotting. Default is None.
    gray : bool
        Whether to load as grayscale or RGB.
    dtype : float32 or float64, default float32
        Data type of returned data, given as a string or as the numpy type.
    single_psf : bool
        Whether to sum RGB channels into single PSF, same across channels.
        Done in "Learned reconstructions for practical mask-based lensless
        imaging" of Kristina Monakhova et. al.
    shape : tuple, optional
        Shape to resize PSF and data to. Required if ``downsample`` is not
        given.
    use_torch : bool, optional
        Whether to return torch tensors instead of numpy arrays.
    torch_device : str, optional
        Device on which to place the tensors if ``use_torch``.
    normalize : bool, default False
        Whether to normalize data to maximum value of 1.
    bgr_input : bool, optional
        Forwarded to the image loaders.

    Returns
    -------
    psf : :py:class:`~numpy.ndarray`
        4-D array of PSF.
    data : :py:class:`~numpy.ndarray`
        4-D array of raw measurement data.
    bg : :py:class:`~numpy.ndarray` or None
        Background, only returned if ``return_bg`` is set. ``None`` if no
        background was estimated or loaded.

    Raises
    ------
    ValueError
        If ``dtype`` is not float32 / float64, or if PSF and data end up with
        a different number of channels.
    """
    assert os.path.isfile(psf_fp)
    assert os.path.isfile(data_fp)
    if shape is None:
        assert downsample is not None

    # accept both the string and the numpy type
    if dtype is None or dtype is np.float32 or dtype == "float32":
        dtype = np.float32
    elif dtype is np.float64 or dtype == "float64":
        dtype = np.float64
    else:
        raise ValueError("dtype must be float32 or float64")

    use_3d = psf_fp.endswith((".npy", ".npz"))

    # load and process PSF data
    bg = None
    res = load_psf(
        psf_fp,
        # with only `shape` given there is no factor; 1 lets `shape` decide
        downsample=1 if downsample is None else downsample,
        return_float=return_float,
        bg_pix=bg_pix,
        return_bg=bg_pix is not None,
        flip=flip,
        flip_ud=flip_ud,
        flip_lr=flip_lr,
        bayer=bayer,
        blue_gain=blue_gain,
        red_gain=red_gain,
        dtype=dtype,
        single_psf=single_psf,
        shape=shape,
        use_3d=use_3d,
        bgr_input=bgr_input,
    )
    if bg_pix is not None:
        psf, bg = res
    else:
        psf = res

    # load and process raw measurement
    data = load_image(
        data_fp,
        flip=flip,
        flip_ud=flip_ud,
        flip_lr=flip_lr,
        bayer=bayer,
        blue_gain=blue_gain,
        red_gain=red_gain,
        bg=bg,
        as_4d=True,
        return_float=return_float,
        shape=shape,
        # with a separate background file, normalization happens below,
        # after the background has been handled
        normalize=normalize if background_fp is None else False,
        bgr_input=bgr_input,
    )

    if background_fp is not None:
        # same orientation as the measurement, otherwise the subtraction
        # below would mix up pixels
        bg = load_image(
            background_fp,
            flip=flip,
            flip_ud=flip_ud,
            flip_lr=flip_lr,
            bayer=bayer,
            blue_gain=blue_gain,
            red_gain=red_gain,
            as_4d=True,
            return_float=return_float,
            shape=shape,
            normalize=False,
            bgr_input=bgr_input,
        )
        assert bg.shape == data.shape

        if remove_background:
            data -= bg
            # clip to 0
            data = np.clip(data, a_min=0, a_max=data.max())

        if normalize:
            # take the factor BEFORE rescaling data, so that the background
            # is really normalized by the same factor
            norm_factor = data.max()
            data /= norm_factor
            bg /= norm_factor

    if data.shape != psf.shape:
        # in DiffuserCam dataset, images are already reshaped
        data = resize(data, shape=psf.shape)
        if background_fp is not None:
            bg = resize(bg, shape=psf.shape)

    if data.shape[3] > 1 and psf.shape[3] == 1:
        warnings.warn(
            "Warning: loaded a grayscale PSF with RGB data. Repeating PSF across channels."
            "This may be an error as the PSF and the data are likely from different datasets."
        )
        psf = np.repeat(psf, data.shape[3], axis=3)
    if data.shape[3] == 1 and psf.shape[3] > 1:
        warnings.warn(
            "Warning: loaded a RGB PSF with grayscale data. Repeating data across channels."
            "This may be an error as the PSF and the data are likely from different datasets."
        )
        data = np.repeat(data, psf.shape[3], axis=3)

    if data.shape[3] != psf.shape[3]:
        raise ValueError(
            "PSF and data must have same number of channels, check that they are from the same dataset."
        )

    if gray:
        psf = np.array(rgb2gray(psf))
        data = np.array(rgb2gray(data))

    if plot:
        ax = plot_image(psf[0], gamma=gamma)
        ax.set_title("PSF of the first depth")
        ax = plot_image(data[0], gamma=gamma)
        ax.set_title("Raw data")

    psf = np.array(psf, dtype=dtype)
    data = np.array(data, dtype=dtype)
    if bg is not None:
        # converting None would silently give a NaN array
        bg = np.array(bg, dtype=dtype)

    if use_torch:
        import torch

        torch_dtype = torch.float32 if dtype == np.float32 else torch.float64
        psf = torch.from_numpy(psf).type(torch_dtype).to(torch_device)
        data = torch.from_numpy(data).type(torch_dtype).to(torch_device)
        if bg is not None:
            bg = torch.from_numpy(bg).type(torch_dtype).to(torch_device)

    if return_bg:
        return psf, data, bg
    else:
        return psf, data
def save_image(img, fp, max_val=255, normalize=True):
    """
    Save as uint8 image.

    Parameters
    ----------
    img : :py:class:`~numpy.ndarray`
        Image data. It is not modified, a copy is processed.
    fp : str
        Path to save the image to.
    max_val : int, optional
        Value that the data is scaled to before casting to uint8.
    normalize : bool, optional
        If True, stretch the data so that its minimum maps to 0 and its
        maximum to ``max_val``. If False, float data is expected within
        [0, 1] (it is shifted / scaled, with a printed warning, if not) and
        non-float data is saved as is.
    """
    img_tmp = img.copy()
    is_float = np.issubdtype(img_tmp.dtype, np.floating)

    if normalize:
        # in-place true division needs a float array, whatever the input type
        if not is_float:
            img_tmp = img_tmp.astype(np.float32)
        img_tmp -= img_tmp.min()
        peak = img_tmp.max()
        if peak > 0:
            # a constant image has zero range; leave it at 0 instead of
            # dividing by zero
            img_tmp /= peak
        img_tmp *= max_val
        img_tmp = img_tmp.astype(np.uint8)
    elif is_float:
        # check within [0, 1] and convert to uint8
        normalized = False
        if img_tmp.min() < 0:
            img_tmp -= img_tmp.min()
            normalized = True
        if img_tmp.max() > 1:
            img_tmp /= img_tmp.max()
            normalized = True
        if normalized:
            print(f"Warning (out of range): {fp} normalizing data to [0, 1]")
        img_tmp *= max_val
        img_tmp = img_tmp.astype(np.uint8)

    # save
    if len(img_tmp.shape) == 3 and img_tmp.shape[2] == 3:
        # RGB
        img_tmp = Image.fromarray(img_tmp)
    else:
        # grayscale
        img_tmp = Image.fromarray(img_tmp.squeeze())
    img_tmp.save(fp)
def get_dtype(dtype=None, is_torch=False):
    """
    Get dtype for numpy or torch.

    Parameters
    ----------
    dtype : str, optional
        "float32" or "float64", Default is "float32".
    is_torch : bool, optional
        Whether to return torch dtype.

    Returns
    -------
    dtype
        ``numpy.float32`` / ``numpy.float64``, or the torch counterpart if
        ``is_torch`` is set.
    """
    if dtype is None:
        dtype = "float32"
    assert dtype == "float32" or dtype == "float64", "dtype must be float32 or float64"

    if is_torch:
        # torch is only imported when actually requested
        import torch

        return torch.float32 if dtype == "float32" else torch.float64

    return np.float32 if dtype == "float32" else np.float64
def get_ctypes(dtype, is_torch):
    """
    Get the complex data types that match a real or complex data type.

    Parameters
    ----------
    dtype : numpy or torch dtype
        float32, float64, complex64 or complex128. Torch dtypes are only
        accepted if ``is_torch`` is set.
    is_torch : bool
        Whether the first returned type should be a torch dtype.

    Returns
    -------
    ctype
        Complex dtype of matching precision: a torch dtype if ``is_torch``,
        otherwise a numpy one.
    numpy_ctype
        Numpy complex dtype of matching precision.

    Raises
    ------
    ValueError
        If ``dtype`` is none of the supported types.
    """
    if not is_torch:
        if dtype == np.float32 or dtype == np.complex64:
            return np.complex64, np.complex64
        if dtype == np.float64 or dtype == np.complex128:
            return np.complex128, np.complex128
    else:
        import torch

        # numpy types are checked first, then their torch counterparts
        lookup = (
            ((np.float32, np.complex64), (torch.complex64, np.complex64)),
            ((np.float64, np.complex128), (torch.complex128, np.complex128)),
            ((torch.float32, torch.complex64), (torch.complex64, np.complex64)),
            ((torch.float64, torch.complex128), (torch.complex128, np.complex128)),
        )
        for candidates, ctypes in lookup:
            if any(dtype == candidate for candidate in candidates):
                return ctypes

    # single formatted message instead of a (str, dtype) argument tuple
    raise ValueError(f"Unexpected dtype: {dtype}")