Source code for sarepy.post.ring_removal_post

#============================================================================
# Copyright (c) 2018 Nghia T. Vo. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#============================================================================
# Author: Nghia T. Vo
# E-mail: nghia.vo@diamond.ac.uk
# Description: Methods for removing ring artifacts in reconstructed images by
# combining the polar transform and stripe removal methods.
# Publication date: 05/05/2020
#============================================================================

"""
Module for ring removal methods which work on reconstruction images
by combining the polar transform and stripe removal methods.
"""

import numpy as np
from scipy.ndimage import map_coordinates
import sarepy.prep.stripe_removal_former as srm


def cirle_mask(width, ratio):
    """
    Create a circle mask.

    Parameters
    -----------
    width : int
        Width of a square array.
    ratio : float
        Ratio between the diameter of the mask and the width of the array.

    Returns
    ------
    ndarray
         Square array.
    """
    mask = np.zeros((width, width), dtype=np.float32)
    center = width // 2
    radius = ratio * center
    y, x = np.ogrid[-center:width - center, -center:width - center]
    mask_check = x * x + y * y <= radius * radius
    mask[mask_check] = 1.0
    return mask


def reg_to_polar(width_reg, height_reg, width_pol, height_pol):
    """
    Generate coordinates of a rectangular grid from polar coordinates.

    Parameters
    -----------
    width_reg : int
        Width of the image in the Cartesian coordinate system.
    height_reg : int
        Height of the image in the Cartesian coordinate system.
    width_pol : int
        Width of the image in the polar coordinate system.
    height_pol : int
        Height of the image in the polar coordinate system.

    Returns
    ------
    x_mat : ndarray
         2D array. Broadcast of the x-coordinates.
    y_mat : ndarray
         2D array. Broadcast of the y-coordinates.

    """
    centerx = (width_reg - 1.0) * 0.5
    centery = (height_reg - 1.0) * 0.5
    r_max = np.floor(max(centerx, centery))
    r_list = np.linspace(0.0, r_max, width_pol)
    theta_list = np.arange(0.0, height_pol, 1.0) * 2 * np.pi / (height_pol - 1)
    r_mat, theta_mat = np.meshgrid(r_list, theta_list)
    x_mat = np.float32(
        np.clip(centerx + r_mat * np.cos(theta_mat), 0, width_reg - 1))
    y_mat = np.float32(
        np.clip(centery + r_mat * np.sin(theta_mat), 0, height_reg - 1))
    return x_mat, y_mat


def polar_to_reg(width_pol, height_pol, width_reg, height_reg):
    """
    Generate polar coordinates from grid coordinates.

    Parameters
    -----------
    width_pol : int
        Width of the image in the polar coordinate system.
    height_pol : int
        Height of the image in the polar coordinate system.
    width_reg : int
        Width of the image in the Cartesian coordinate system.
    height_reg : int
        Height of the image in the Cartesian coordinate system.

    Returns
    ------
    r_mat : ndarray
         2D array. Broadcast of the r-coordinates.
    theta_mat : ndarray
         2D array. Broadcast of the theta-coordinates.

    """
    centerx = (width_reg - 1.0) * 0.5
    centery = (height_reg - 1.0) * 0.5
    r_max = np.floor(max(centerx, centery))
    x_list = (1.0 * np.flipud(np.arange(width_reg)) -
              centerx) * width_pol / r_max
    y_list = (1.0 * np.flipud(np.arange(height_reg)) -
              centery) * width_pol / r_max
    x_mat, y_mat = np.meshgrid(x_list, y_list)
    r_mat = np.float32(np.clip(np.sqrt(x_mat**2 + y_mat**2), 0, width_pol - 1))
    theta_mat = np.float32(np.clip((np.pi + np.arctan2(y_mat, x_mat))
                                   * (height_pol - 1) / (2 * np.pi), 0, height_pol - 1))
    return r_mat, theta_mat


def mapping(mat, xmat, ymat):
    """
    Mapping an 2D array at coordinates (xmat, ymat) to rectangular
    grid coordinates.

    Parameters
    ----------
    mat : array_like
        2D array.
    xmat : array_like
        2D array of the x-coordinates.
    ymat : array_like
        2D array of the y-coordinates.

    Returns
    -------
    ndarray
        2D array.
    """
    indices = np.reshape(ymat, (-1, 1)), np.reshape(xmat, (-1, 1))
    mat = map_coordinates(mat, indices, order=1, mode='reflect')
    return mat.reshape(xmat.shape)


[docs]def ring_removal_based_wavelet_fft(mat, level=5, sigma=1.0, order=8, pad=150): """ Remove ring artifacts in the reconstructed image by combining the polar transform and the wavelet-fft-based method (Ref. [1]). Parameters ---------- mat : array_like Square array. Reconstructed image level : int Wavelet decomposition level. sigma : int Damping parameter. Larger is stronger. order : int Order of the the Daubechies wavelets. pad : int Padding for FFT Returns ------- ndarray Square array. Ring-removed image. References ---------- .. [1] https://doi.org/10.1364/OE.17.008567 """ (nrow, ncol) = mat.shape if nrow != ncol: raise ValueError( "Width and height of the reconstructed image are not the same") mask = cirle_mask(ncol, 1.0) (xmat, ymat) = reg_to_polar(ncol, ncol, ncol, ncol) (rmat, thetamat) = polar_to_reg(ncol, ncol, ncol, ncol) polar_mat = mapping(mat, xmat, ymat) polar_mat = srm.remove_stripe_based_wavelet_fft( polar_mat, level, sigma, order, pad) mat_rec = mapping(polar_mat, rmat, thetamat) return mat_rec * mask
[docs]def ring_removal_based_fft(mat, u=30, n=8, v=1, pad=150): """ Remove ring artifacts in the reconstructed image by combining the polar transform and the fft-based method (Ref. [1]). Parameters ---------- mat : array_like Square array. Reconstructed image u,n : int To define the shape of 1D Butterworth low-pass filter. v : int Number of rows (* 2) to be applied the filter. pad : int Padding for FFT Returns ------- ndarray Square array. Ring-removed image. References ---------- .. [1] https://doi.org/10.1063/1.1149043 """ (nrow, ncol) = mat.shape if nrow != ncol: raise ValueError( "Width and height of the reconstructed image are not the same") mask = cirle_mask(ncol, 1.0) (xmat, ymat) = reg_to_polar(ncol, ncol, ncol, ncol) (rmat, thetamat) = polar_to_reg(ncol, ncol, ncol, ncol) polar_mat = mapping(mat, xmat, ymat) polar_mat = srm.remove_stripe_based_fft(polar_mat, u, n, v, pad) mat_rec = mapping(polar_mat, rmat, thetamat) return mat_rec * mask