Source code for mpol.images

r"""The ``images`` module provides the core functionality of MPoL via
:class:`mpol.images.ImageCube`."""

from __future__ import annotations

from collections.abc import Callable

import math
from typing import Any

import numpy as np
import numpy.typing as npt
import torch
import torch.fft  # to avoid conflicts with old torch.fft *function*
from torch import nn

from mpol import constants, utils
from mpol.coordinates import GridCoords


class BaseCube(nn.Module):
    r"""
    A base cube of the same dimensions as the image cube. Designed to use a pixel
    mapping function :math:`f_\mathrm{map}` from the base cube values to the
    ImageCube domain.

    .. math::

        I = f_\mathrm{map}(b)

    The ``base_cube`` pixel values are set as PyTorch `parameters
    <https://pytorch.org/docs/stable/generated/torch.nn.parameter.Parameter.html>`_.

    Parameters
    ----------
    coords : :class:`mpol.coordinates.GridCoords`
        an object instantiated from the GridCoords class, containing information
        about the image `cell_size` and `npix`.
    nchan : int
        the number of channels in the base cube. Default = 1.
    pixel_mapping : function
        a PyTorch function mapping the base pixel representation to the cube
        representation. If `None`, defaults to `torch.nn.Softplus()`. Output of the
        function should be in units of [:math:`\mathrm{Jy}\,\mathrm{arcsec}^{-2}`].
    base_cube : torch.double tensor, optional
        a pre-packed base cube to initialize the model with. If None, assumes
        ``torch.zeros``.

    See :ref:`cube-orientation-label` for more information on the expectations of
    the orientation of the input image.
    """

    def __init__(
        self,
        coords: GridCoords,
        nchan: int = 1,
        pixel_mapping: Callable[[torch.Tensor], torch.Tensor] | None = None,
        base_cube: torch.Tensor | None = None,
    ) -> None:
        super().__init__()

        self.coords = coords
        self.nchan = nchan

        # The ``base_cube`` is already packed to make the Fourier transformation
        # easier
        if base_cube is None:
            # base_cube = -3 yields a nearly-blank cube after softplus, whereas
            # base_cube = 0.0 yields a cube with avg value of ~0.7, which is too high
            self.base_cube = nn.Parameter(
                -3
                * torch.ones(
                    (self.nchan, self.coords.npix, self.coords.npix),
                    requires_grad=True,
                )
            )
        else:
            # We expect the user to supply a pre-packed base cube
            # so that it's ready to go for the FFT.
            # We could apply this transformation for the user, but I think it will
            # lead to less confusion if we make this transformation explicit
            # for the user during the setup phase.
            self.base_cube = nn.Parameter(base_cube, requires_grad=True)

        if pixel_mapping is None:
            self.pixel_mapping: Callable[
                [torch.Tensor], torch.Tensor
            ] = torch.nn.Softplus()
        else:
            self.pixel_mapping = pixel_mapping
    def forward(self) -> torch.Tensor:
        r"""
        Calculate the image representation from the ``base_cube`` using the pixel
        mapping

        .. math::

            I = f_\mathrm{map}(b)

        Returns
        -------
        :class:`torch.Tensor`
            an image cube in units of [:math:`\mathrm{Jy}\,\mathrm{arcsec}^{-2}`].
        """
        return self.pixel_mapping(self.base_cube)
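
# A minimal usage sketch for BaseCube. The GridCoords arguments follow its
# docstring (`cell_size` in arcsec, `npix`); the specific values here are
# illustrative, not defaults:
#
#     from mpol.coordinates import GridCoords
#     from mpol.images import BaseCube
#
#     coords = GridCoords(cell_size=0.005, npix=800)
#     bcube = BaseCube(coords=coords, nchan=1)
#     packed_image = bcube()  # softplus(base_cube), packed format, Jy/arcsec^2
#     assert packed_image.shape == (1, 800, 800)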
class HannConvCube(nn.Module):
    r"""
    This convolutional layer convolves an input cube by a small 3x3 filter with
    shape

    .. math::

        \begin{bmatrix}
        0.0625 & 0.1250 & 0.0625 \\
        0.1250 & 0.2500 & 0.1250 \\
        0.0625 & 0.1250 & 0.0625 \\
        \end{bmatrix}

    which is the 2D version of the discretely-sampled response function
    corresponding to a Hann window, i.e., it is two 1D Hann windows multiplied
    together. This is a convolutional kernel in the image plane, and so effectively
    acts as apodization by a Hann window function in the Fourier domain. For more
    information, see the following Wikipedia articles on `Window Functions
    <https://en.wikipedia.org/wiki/Window_function>`_ in general and the
    `Hann Window <https://en.wikipedia.org/wiki/Hann_function>`_ specifically.

    The idea is that this layer would help naturally attenuate high spatial
    frequency artifacts by baking in a natural apodization in the Fourier plane.

    Args:
        nchan (int): number of channels
        requires_grad (bool): if True, the kernel values are treated as trainable
            parameters; the default (False) keeps the kernel fixed.
    """

    def __init__(self, nchan: int, requires_grad: bool = False) -> None:
        super().__init__()
        # simple convolutional filter operates on per-channel basis
        # 3x3 Hann filter
        self.m = nn.Conv2d(
            in_channels=nchan,
            out_channels=nchan,
            kernel_size=3,
            stride=1,
            groups=nchan,
            padding=1,  # required to get same sized output for 3x3 kernel
        )

        # weights has size (nchan, 1, 3, 3)
        # bias has shape (nchan)

        # build out the discretely-sampled Hann kernel
        spec = torch.tensor([0.25, 0.5, 0.25])
        nugget = torch.outer(spec, spec)  # shape (3, 3) 2D Hann kernel
        exp = torch.unsqueeze(torch.unsqueeze(nugget, 0), 0)  # shape (1, 1, 3, 3)
        weight = exp.repeat(nchan, 1, 1, 1)  # shape (nchan, 1, 3, 3)

        # set the (untunable) weight
        self.m.weight = nn.Parameter(weight, requires_grad=requires_grad)

        # set the bias to zero
        self.m.bias = nn.Parameter(torch.zeros(nchan), requires_grad=requires_grad)
    def forward(self, cube: torch.Tensor) -> torch.Tensor:
        r"""
        Args:
            cube (torch.double tensor, of shape ``(nchan, npix, npix)``): a
                prepacked image cube, for example, from ImageCube.forward()

        Returns:
            torch.double tensor: the convolved image cube, in packed format and of
            shape ``(nchan, npix, npix)``
        """
        # Conv2d is designed to work on batches, so some extra unsqueeze/squeezing
        # action is required.
        # Additionally, the convolution must be done on the *sky-oriented* cube
        sky_cube = utils.packed_cube_to_sky_cube(cube)

        # augment extra "batch" dimension to cube, to make it (1, nchan, npix, npix)
        aug_sky_cube = torch.unsqueeze(sky_cube, dim=0)

        # do convolution
        conv_aug_sky_cube = self.m(aug_sky_cube)

        # remove extra "batch" dimension
        # we're not using squeeze here, since there's a possibility that nchan=1
        # and we want to keep that dimension
        conv_sky_cube = conv_aug_sky_cube[0]

        # return in packed format
        return utils.sky_cube_to_packed_cube(conv_sky_cube)
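
# A minimal sketch of chaining HannConvCube after a BaseCube (names reused from
# the BaseCube sketch above; shapes follow the docstrings):
#
#     conv_layer = HannConvCube(nchan=1)
#     smoothed_packed = conv_layer(bcube())  # still (nchan, npix, npix), packed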
class GaussConvImage(nn.Module):
    r"""
    This convolutional layer will convolve the input cube with a 2D Gaussian
    kernel. The filter is the same for all channels in the input cube. Because the
    operation is carried out in the image domain, it may become computationally
    prohibitive for large kernel sizes. In that case,
    :class:`mpol.images.GaussConvFourier` may be preferred.

    Parameters
    ----------
    coords : :class:`mpol.coordinates.GridCoords`
        an object instantiated from the GridCoords class, containing information
        about the image `cell_size` and `npix`.
    nchan : int
        the number of channels in the base cube. Default = 1.
    FWHM_maj : float, units of arcsec
        the FWHM of the Gaussian along the major axis
    FWHM_min : float, units of arcsec
        the FWHM of the Gaussian along the minor axis
    Omega : float, degrees
        the rotation of the major axis of the PSF, in degrees East of North.
        0 degrees rotation has the major axis aligned in the North-South direction.
    requires_grad : bool
        Should the kernel parameters be trainable? Most applications will want to
        use `False`.
    """

    def __init__(
        self,
        coords: GridCoords,
        nchan: int,
        FWHM_maj: float,
        FWHM_min: float,
        Omega: float = 0,
        requires_grad: bool = False,
    ) -> None:
        super().__init__()

        # convert FWHM to sigma, using FWHM = 2 * sqrt(2 ln 2) * sigma
        FWHM2sigma = 1 / (2 * np.sqrt(2 * np.log(2)))

        # In this routine, x, y are used in the same sense as the GridCoords
        # object uses 'sky_x' and 'sky_y', i.e. x is l in arcseconds and
        # y is m in arcseconds.
        # assumes major axis along m direction at 0 degrees rotation.
        sigma_y = FWHM_maj * FWHM2sigma  # arcsec
        sigma_x = FWHM_min * FWHM2sigma  # arcsec

        # calculate filter out to some Gaussian width, and make a kernel with an
        # odd number of pixels
        limit = 3.0 * sigma_y
        npix_kernel = 1 + 2 * math.ceil(limit / coords.cell_size)
        if npix_kernel < 3:
            raise RuntimeError(
                """FWHM_maj is so small ({:} arcsec) relative to the cell_size
                ({:} arcsec) that the convolutional kernel would only be one pixel
                wide. Increase FWHM_maj or remove this convolutional layer
                entirely""".format(
                    FWHM_maj, coords.cell_size
                )
            )

        # create a grid to evaluate the 2D Gaussian, using an odd number of
        # pixels so that the kernel is centered on a peak pixel
        kernel_centers = np.linspace(-limit, limit, num=npix_kernel)  # [arcsec]

        # borrowed from GridCoords logic
        x_centers_2D = np.tile(kernel_centers, (npix_kernel, 1))  # [arcsec]
        sky_x_centers_2D = np.fliplr(x_centers_2D)
        sky_y_centers_2D = np.tile(kernel_centers, (npix_kernel, 1)).T  # [arcsec]

        # evaluate Gaussian over grid
        gauss = utils.sky_gaussian_arcsec(
            sky_x_centers_2D,
            sky_y_centers_2D,
            1.0,
            delta_x=0.0,
            delta_y=0.0,
            sigma_x=sigma_x,
            sigma_y=sigma_y,
            Omega=Omega,
        )

        # normalize kernel to keep total flux the same
        gauss /= np.sum(gauss)

        nugget = torch.tensor(gauss, dtype=torch.float32)
        exp = torch.unsqueeze(
            torch.unsqueeze(nugget, 0), 0
        )  # shape (1, 1, npix_kernel, npix_kernel)
        weight = exp.repeat(
            nchan, 1, 1, 1
        )  # shape (nchan, 1, npix_kernel, npix_kernel)

        # groups = nchan will give us the minimal set of filters we need
        # somewhat confusingly, the neural network literature calls this
        # a "depthwise" convolution. I think that "depthwise" is not meant to imply
        # that there is now a consideration of the depth (e.g., color channel)
        # dimension when before there wasn't.
        # Rather, the emphasis is on the *wise*, as in "pairwise," in that
        # each depth channel is treated individually with its own filter, rather
        # than a filter that draws from multiple depth channels at once.
        # I think a better name is "channel-separate" convolution as indicated in
        # the "Understanding Deep Learning" textbook by Prince in Ch 10.6.

        # simple convolutional filter operates on per-channel basis
        self.m = nn.Conv2d(
            in_channels=nchan,
            out_channels=nchan,
            kernel_size=npix_kernel,
            stride=1,
            groups=nchan,
            padding="same",
        )

        # weights has size (nchan, 1, npix_kernel, npix_kernel)
        # bias has shape (nchan)

        # set the (untunable) weight
        self.m.weight = nn.Parameter(weight, requires_grad=requires_grad)

        # set the bias to zero
        self.m.bias = nn.Parameter(torch.zeros(nchan), requires_grad=requires_grad)
    def forward(self, sky_cube: torch.Tensor) -> torch.Tensor:
        r"""
        Args:
            sky_cube (torch.double tensor, of shape ``(nchan, npix, npix)``): an
                image cube in sky format (note, not packed).

        Returns:
            torch.double tensor: the convolved image cube, in sky format and of
            shape ``(nchan, npix, npix)``
        """
        convolved_sky: torch.Tensor = self.m(sky_cube)
        return convolved_sky
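
# A minimal sketch of smoothing a sky-oriented cube with GaussConvImage (the
# FWHM values are illustrative; `coords` and `packed_image` as in the BaseCube
# sketch above):
#
#     layer = GaussConvImage(coords, nchan=1, FWHM_maj=0.1, FWHM_min=0.05)
#     sky = utils.packed_cube_to_sky_cube(packed_image)
#     smoothed_sky = layer(sky)  # sky format, same shape as input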
class GaussConvFourier(nn.Module):
    r"""
    This layer will convolve the input cube with a (potentially non-circular)
    Gaussian beam, using a Fourier strategy. The size of the beam is set upon
    initialization of the layer.

    Parameters
    ----------
    coords : :class:`mpol.coordinates.GridCoords`
        an object instantiated from the GridCoords class, containing information
        about the image `cell_size` and `npix`.
    FWHM_maj : float, units of arcsec
        the FWHM of the Gaussian along the major axis
    FWHM_min : float, units of arcsec
        the FWHM of the Gaussian along the minor axis
    Omega : float, degrees
        the rotation of the major axis of the PSF, in degrees East of North.
        0 degrees rotation has the major axis aligned in the North-South direction.
    """

    def __init__(
        self,
        coords: GridCoords,
        FWHM_maj: float,
        FWHM_min: float,
        Omega: float = 0,
    ) -> None:
        super().__init__()

        self.coords = coords
        self.FWHM_maj = FWHM_maj
        self.FWHM_min = FWHM_min
        self.Omega = Omega

        taper_2D = uv_gaussian_taper(
            self.coords, self.FWHM_maj, self.FWHM_min, self.Omega
        )

        # store the taper as a registered buffer so it transfers to the GPU
        # along with the module
        self.register_buffer("taper_2D", torch.tensor(taper_2D, dtype=torch.float32))
    def forward(self, packed_cube):
        r"""
        Convolve a packed_cube image with a 2D Gaussian PSF. Operation is carried
        out in the Fourier domain using a Gaussian taper.

        Parameters
        ----------
        packed_cube : :class:`torch.Tensor` type
            shape ``(nchan, npix, npix)`` image cube in packed format.

        Returns
        -------
        :class:`torch.Tensor`
            The convolved cube in packed format.
        """
        nchan, npix_m, npix_l = packed_cube.size()
        assert (npix_m == self.coords.npix) and (
            npix_l == self.coords.npix
        ), "packed_cube {:} does not have the same pixel dimensions as indicated by coords {:}".format(
            packed_cube.size(), self.coords.npix
        )

        # in FFT packed format
        # we're round-tripping, so we can ignore prefactors for correctness
        # calling this `vis_like`, since it's not actually the vis
        vis_like = torch.fft.fftn(packed_cube, dim=(1, 2))

        # apply the taper to the FFT of the packed image
        tapered_vis = vis_like * torch.broadcast_to(self.taper_2D, packed_cube.size())

        # iFFT back, ignoring prefactors for round-trip
        convolved_packed_cube = torch.fft.ifftn(tapered_vis, dim=(1, 2))

        # assert imaginaries are effectively zero, otherwise something went wrong
        thresh = 1e-7
        assert (
            torch.max(convolved_packed_cube.imag) < thresh
        ), "Round-tripped image contains max imaginary value {:} > {:} threshold, something may be amiss.".format(
            torch.max(convolved_packed_cube.imag), thresh
        )

        r_cube: torch.Tensor = convolved_packed_cube.real
        return r_cube
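
# A minimal sketch of the Fourier-domain convolution (the FWHM and Omega values
# are illustrative; `coords` and `packed_image` as in the sketches above):
#
#     beam = GaussConvFourier(coords, FWHM_maj=0.1, FWHM_min=0.05, Omega=20.0)
#     convolved_packed = beam(packed_image)  # packed format, real-valued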
class ImageCube(nn.Module):
    r"""
    The parameter set is the pixel values of the image cube itself. The pixels are
    assumed to represent samples of the specific intensity and are given in units
    of [:math:`\mathrm{Jy}\,\mathrm{arcsec}^{-2}`].

    All keyword arguments are required unless noted.

    ImageCube is essentially an identity layer, since no transformations are
    applied to the ``packed_cube`` tensor passing through it. Its main purpose is
    to provide useful functionality around the cube, such as returning a sky_cube
    representation and providing FITS writing functionality.

    Parameters
    ----------
    coords : :class:`mpol.coordinates.GridCoords`
        an object instantiated from the GridCoords class, containing information
        about the image `cell_size` and `npix`.
    nchan : int
        the number of channels in the base cube. Default = 1.
    """

    def __init__(
        self,
        coords: GridCoords,
        nchan: int = 1,
    ) -> None:
        super().__init__()
        self.coords = coords
        self.nchan = nchan
        self.register_buffer("packed_cube", None)
    def forward(self, packed_cube: torch.Tensor) -> torch.Tensor:
        r"""
        Pass the cube through as an identity operation, storing the value to the
        internal buffer. After the cube has been passed through, convenience
        instance attributes like `sky_cube` and `flux` will reflect the updated
        cube.

        Parameters
        ----------
        packed_cube : :class:`torch.Tensor` of type :class:`torch.double`
            3D torch tensor of shape ``(nchan, npix, npix)`` in 'packed' format

        Returns
        -------
        :class:`torch.Tensor`
            tensor of shape ``(nchan, npix, npix)``, same as the input
            ``packed_cube``
        """
        self.packed_cube = packed_cube
        return self.packed_cube
    @property
    def sky_cube(self) -> torch.Tensor:
        """
        The image cube arranged as it would appear on the sky.

        Returns:
            torch.double: 3D image cube of shape ``(nchan, npix, npix)``
        """
        return utils.packed_cube_to_sky_cube(self.packed_cube)

    @property
    def flux(self) -> torch.Tensor:
        """
        The spatially-integrated flux of the image. Returns a 'spectrum' with the
        flux in each channel in units of Jy.

        Returns:
            torch.double: a 1D tensor of length ``nchan``
        """
        # convert from Jy/arcsec^2 to Jy/pixel using the area of a pixel,
        # i.e., multiply by arcsec^2/pixel
        return self.coords.cell_size**2 * torch.sum(self.packed_cube, dim=(1, 2))
    def to_FITS(
        self,
        fname: str = "cube.fits",
        overwrite: bool = False,
        header_kwargs: dict | None = None,
    ) -> None:
        """
        Export the image cube to a FITS file.

        Args:
            fname (str): the name of the FITS file to export to.
            overwrite (bool): if the file already exists, overwrite?
            header_kwargs (dict): extra keyword arguments to write to the FITS
                header.

        Returns:
            None
        """
        from astropy import wcs
        from astropy.io import fits

        w = wcs.WCS(naxis=2)
        w.wcs.crpix = np.array([1, 1])
        w.wcs.cdelt = (
            np.array([self.coords.cell_size, self.coords.cell_size]) / 3600
        )  # decimal degrees
        w.wcs.ctype = ["RA---TAN", "DEC--TAN"]

        header = w.to_header()

        # add in the kwargs to the header
        if header_kwargs is not None:
            for k, v in header_kwargs.items():
                header[k] = v

        hdu = fits.PrimaryHDU(utils.torch2npy(self.sky_cube), header=header)

        hdul = fits.HDUList([hdu])
        hdul.writeto(fname, overwrite=overwrite)

        hdul.close()
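
# A minimal sketch of routing a cube through ImageCube and exporting it (names
# reused from the sketches above; "my_cube.fits" is an illustrative filename):
#
#     icube = ImageCube(coords=coords, nchan=1)
#     icube(packed_image)        # identity pass; stores the internal buffer
#     spectrum = icube.flux      # 1D tensor of per-channel flux, in Jy
#     icube.to_FITS(fname="my_cube.fits", overwrite=True)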
def uv_gaussian_taper(
    coords: GridCoords, FWHM_maj: float, FWHM_min: float, Omega: float
) -> npt.NDArray[np.floating[Any]]:
    r"""
    Compute a packed Gaussian taper in the Fourier domain, to multiply against a
    packed visibility cube. While similar to
    :meth:`mpol.utils.fourier_gaussian_lambda_arcsec`, this routine delivers a
    visibility-plane taper with maximum amplitude normalized to 1.0.

    Parameters
    ----------
    coords : :class:`mpol.coordinates.GridCoords`
        object indicating image and Fourier grid specifications.
    FWHM_maj : float, units of arcsec
        the FWHM of the Gaussian along the major axis
    FWHM_min : float, units of arcsec
        the FWHM of the Gaussian along the minor axis
    Omega : float, degrees
        the rotation of the major axis of the PSF, in degrees East of North.
        0 degrees rotation has the major axis aligned in the East-West direction.

    Returns
    -------
    :class:`np.ndarray`, shape ``(npix, npix)``
        The Gaussian taper in packed format.
    """
    # convert FWHM to sigma and to radians
    FWHM2sigma = 1 / (2 * np.sqrt(2 * np.log(2)))
    sigma_l = FWHM_maj * FWHM2sigma * constants.arcsec  # radians
    sigma_m = FWHM_min * FWHM2sigma * constants.arcsec  # radians

    u = coords.packed_u_centers_2D
    v = coords.packed_v_centers_2D

    # calculate primed (rotated) coordinates
    Omega_d = Omega * constants.deg
    up = u * np.cos(Omega_d) - v * np.sin(Omega_d)
    vp = u * np.sin(Omega_d) + v * np.cos(Omega_d)

    # calculate the Fourier Gaussian
    taper_2D: npt.NDArray[np.floating[Any]]
    taper_2D = np.exp(-2 * np.pi**2 * (sigma_l**2 * up**2 + sigma_m**2 * vp**2))

    # a flux-conserving taper must have an amplitude of 1 at the origin,
    # which exp(0) = 1 already satisfies
    return taper_2D
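
# A minimal sketch of evaluating the taper and checking its normalization
# (`coords` as above; in packed, FFT-ready format the u = v = 0 cell sits at
# index [0, 0], so the taper amplitude there should be exactly 1.0):
#
#     taper = uv_gaussian_taper(coords, FWHM_maj=0.1, FWHM_min=0.05, Omega=0.0)
#     assert taper.shape == (coords.npix, coords.npix)
#     assert taper[0, 0] == 1.0  # exp(0) at the origin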