# Source code for mpol.datasets

from __future__ import annotations

from typing import Any

import numpy as np
import torch
from numpy import floating, integer
from numpy.typing import ArrayLike, NDArray

from mpol import utils
from mpol.coordinates import GridCoords


class GriddedDataset(torch.nn.Module):
    r"""
    Gridded visibility data and weights, together with the boolean mask that
    selects the grid cells containing data.

    Parameters
    ----------
    coords : :class:`~mpol.coordinates.GridCoords`
        the grid coordinates associated with the data; stored as
        ``self.coords``.
    vis_gridded : :class:`torch.Tensor` of :class:`torch.complex128`
        the gridded visibility data stored in a "packed" format (pre-shifted
        for fft)
    weight_gridded : :class:`torch.Tensor`
        the weights corresponding to the gridded visibility data, also in a
        packed format
    mask : :class:`torch.Tensor` of :class:`torch.bool`
        a boolean mask to index the non-zero locations of ``vis_gridded`` and
        ``weight_gridded`` in their packed format.
    nchan : int
        the number of channels in the image (default = 1).

    ``vis_gridded``, ``weight_gridded`` and ``mask`` are registered as module
    buffers, so they move with the module under ``.to(device)``.

    After initialization, the GriddedDataset provides the non-zero cells of
    the gridded visibilities and weights as a 1D vector via the following
    instance variables. This means that any individual channel information
    has been collapsed.

    :ivar vis_indexed: 1D complex tensor of visibility data
    :ivar weight_indexed: 1D tensor of weight values

    If you index the output of the Fourier layer in the same manner using
    ``self.mask``, then the model and data visibilities can be directly
    compared using a loss function.
    """

    def __init__(
        self,
        *,
        coords: GridCoords,
        vis_gridded: torch.Tensor,
        weight_gridded: torch.Tensor,
        mask: torch.Tensor,
        nchan: int = 1,
    ) -> None:
        super().__init__()

        self.coords = coords
        self.nchan = nchan

        # store variables as buffers of the module
        self.register_buffer("vis_gridded", vis_gridded)
        self.register_buffer("weight_gridded", weight_gridded)
        self.register_buffer("mask", mask)
        self.vis_gridded: torch.Tensor
        self.weight_gridded: torch.Tensor
        self.mask: torch.Tensor

        # pre-index the values
        # note that these are *collapsed* across all channels
        # 1D array
        self.register_buffer("vis_indexed", self.vis_gridded[self.mask])
        self.register_buffer("weight_indexed", self.weight_gridded[self.mask])
        self.vis_indexed: torch.Tensor
        self.weight_indexed: torch.Tensor

    def add_mask(
        self,
        mask: ArrayLike,
    ) -> None:
        r"""
        Apply an additional mask to the data. Only works as a data limiting
        operation (i.e., ``mask`` is more restrictive than the mask already
        attached to the dataset).

        Args:
            mask (2D numpy array, array-like, or PyTorch tensor): mask (in
                packed format) to apply to dataset. Nonzero / ``True`` entries
                are kept. Assumes input will be broadcast across all channels.

        Note:
            ``vis_gridded`` and ``weight_gridded`` are zeroed *in place*
            outside the combined mask, which also affects the tensors that
            were originally passed to the constructor.
        """
        # torch.as_tensor accepts numpy arrays, sequences and tensors of any
        # dtype (the legacy torch.Tensor(...) constructor rejects bool
        # tensors), and placing the result on the dataset's device avoids a
        # device mismatch in logical_and. .bool() keeps the original
        # "nonzero means keep" semantics for numeric inputs.
        new_2D_mask = torch.as_tensor(mask, device=self.mask.device).bool()
        new_3D_mask = torch.broadcast_to(new_2D_mask, self.mask.size())

        # update mask via an AND operation, we will only keep visibilities that are
        # 1) part of the original dataset
        # 2) valid within the new mask
        self.mask = torch.logical_and(self.mask, new_3D_mask)

        # zero out vis_gridded and weight_gridded that may have existed
        # but are no longer valid
        # These operations on the gridded quantities are only important for routines
        # that grab these quantities directly, like residual grid imager
        self.vis_gridded[~self.mask] = 0.0
        self.weight_gridded[~self.mask] = 0.0

        # update pre-indexed values
        self.vis_indexed = self.vis_gridded[self.mask]
        self.weight_indexed = self.weight_gridded[self.mask]

    def forward(self, modelVisibilityCube: torch.Tensor) -> torch.Tensor:
        """
        Index a model visibility cube with this dataset's mask.

        Args:
            modelVisibilityCube (complex torch.tensor): with shape
                ``(nchan, npix, npix)`` to be indexed. In "pre-packed" format,
                as in output from :meth:`mpol.fourier.FourierCube.forward()`

        Returns:
            torch complex tensor: 1d torch tensor of indexed model samples
            collapsed across cube dimensions.

        Raises:
            AssertionError: if the number of channels of
                ``modelVisibilityCube`` differs from that of ``self.mask``.
        """
        assert (
            modelVisibilityCube.size()[0] == self.mask.size()[0]
        ), "vis and dataset mask do not have the same number of channels."

        # As of Pytorch 1.7.0, complex numbers are partially supported.
        # However, masked_select does not yet work (with gradients)
        # on the complex vis, so hence this awkward step of selecting
        # the reals and imaginaries separately
        re = modelVisibilityCube.real.masked_select(self.mask)
        im = modelVisibilityCube.imag.masked_select(self.mask)

        # we had trouble returning things as re + 1.0j * im,
        # but for some reason torch.complex seems to work OK.
        return torch.complex(re, im)

    @property
    def ground_mask(self) -> torch.Tensor:
        r"""
        The boolean mask, arranged in ground format.

        Returns:
            torch.boolean : 3D mask cube of shape ``(nchan, npix, npix)``

        """
        return utils.packed_cube_to_ground_cube(self.mask)
class Dartboard:
    r"""
    A polar coordinate grid relative to a :class:`~mpol.coordinates.GridCoords`
    object, reminiscent of a dartboard layout. The main utility of this object
    is to support splitting a dataset along radial and azimuthal bins for
    k-fold cross validation.

    Args:
        coords (GridCoords): an object already instantiated from the GridCoords
            class. Only its ``packed_q_centers_2D``, ``packed_phi_centers_2D``
            and ``q_max`` attributes are read.
        q_edges (1D array-like): radial bin edges that set the dartboard cells,
            in the same units as ``coords.q_max`` / ``coords.packed_q_centers_2D``.
            Lists are accepted and converted to a numpy array. If ``None``,
            defaults to ``utils.loglinspace(0, q_max, N_log=8, M_linear=5)``:
            a few linearly spaced edges near the origin followed by log-spaced
            edges out to the :math:`q_\mathrm{max}` represented by ``coords``.
        phi_edges (1D array-like): azimuthal bin edges in [radians] over the
            domain :math:`[0, \pi]`. Each cell is also implicitly mirrored into
            :math:`[-\pi, 0]` to preserve the Hermitian nature of the
            visibilities. Lists are accepted and converted to a numpy array.
            If ``None``, defaults to 8 equal-width azimuthal bins stretched from
            :math:`0` to :math:`\pi`.

    Raises:
        ValueError: if any element of ``phi_edges`` lies outside
            :math:`[0, \pi]`.
    """

    def __init__(
        self,
        coords: GridCoords,
        q_edges: ArrayLike | None = None,
        phi_edges: ArrayLike | None = None,
    ) -> None:
        self.coords = coords
        self.nchan = 1

        if phi_edges is None:
            phi_edges = np.linspace(0, np.pi, num=8 + 1)  # [radians]
        else:
            # coerce so that plain lists/tuples support the array arithmetic
            # (``- np.pi``) and ``.tolist()`` used by the methods below
            phi_edges = np.asarray(phi_edges, dtype=float)
            if not np.all((phi_edges >= 0) & (phi_edges <= np.pi)):
                raise ValueError("Elements of phi_edges must be between 0 and pi.")

        if q_edges is None:
            # set q edges approximately following inspiration from the
            # Petry et al. scheme:
            # https://ui.adsabs.harvard.edu/abs/2020SPIE11449E..1DP/abstract
            # there, the first two bins are 7m wide and the bin width then
            # grows linearly until it is 700m at a 16km baseline.
            # We aren't doing *quite* the same thing,
            # just logspacing with a few linear cells at the start.
            q_edges = utils.loglinspace(0, self.q_max, N_log=8, M_linear=5)

        self.q_edges = np.asarray(q_edges, dtype=float)
        self.phi_edges = phi_edges

    @property
    def cartesian_qs(self) -> NDArray[floating[Any]]:
        """Radial coordinate of every packed grid cell (``coords.packed_q_centers_2D``)."""
        return self.coords.packed_q_centers_2D

    @property
    def cartesian_phis(self) -> NDArray[floating[Any]]:
        """Azimuth of every packed grid cell (``coords.packed_phi_centers_2D``)."""
        return self.coords.packed_phi_centers_2D

    @property
    def q_max(self) -> float:
        """Maximum radial coordinate represented by ``coords``."""
        return self.coords.q_max

    def get_polar_histogram(
        self, qs: NDArray[floating[Any]], phis: NDArray[floating[Any]]
    ) -> NDArray[floating[Any]]:
        r"""
        Calculate a histogram in polar coordinates, using the bin edges defined
        by ``q_edges`` and ``phi_edges`` during initialization.

        Data coordinates should include the points for the Hermitian
        visibilities. Points whose ``q`` or ``phi`` fall outside the range of
        the edges (e.g. negative azimuths) are not counted, as per
        :func:`numpy.histogram2d`.

        Args:
            qs: 1d array of q values, in the same units as ``q_edges``
            phis: 1d array of datapoint azimuth values [radians] (must be the
                same length as qs)

        Returns:
            2d numpy array of shape ``(len(q_edges) - 1, len(phi_edges) - 1)``
            holding the number of datapoints that fell into each dartboard
            cell. The counts are whole numbers stored with a float dtype, as
            returned by :func:`numpy.histogram2d`.
        """
        histogram: NDArray
        # make a polar histogram
        histogram, *_ = np.histogram2d(
            qs, phis, bins=[self.q_edges.tolist(), self.phi_edges.tolist()]
        )

        return histogram

    def get_nonzero_cell_indices(
        self, qs: NDArray[floating[Any]], phis: NDArray[floating[Any]]
    ) -> NDArray[integer[Any]]:
        r"""
        Return the indices of the cells that contain data points, using the bin
        edges defined by ``q_edges`` and ``phi_edges`` during initialization.

        Data coordinates should include the points for the Hermitian
        visibilities.

        Args:
            qs: 1d array of q values, in the same units as ``q_edges``
            phis: 1d array of datapoint azimuth values [radians] (must be the
                same length as qs)

        Returns:
            2d integer array of shape ``(ncells, 2)``; each row is a
            ``[q_cell, phi_cell]`` index pair of a cell containing at least
            one datapoint.
        """
        histogram = self.get_polar_histogram(qs, phis)

        # [i,j] indexes to go to q, phi
        return np.argwhere(histogram > 0)

    def build_grid_mask_from_cells(
        self, cell_index_list: NDArray[integer[Any]]
    ) -> NDArray[np.bool_]:
        r"""
        Create a boolean mask of size ``(npix, npix)`` (in packed format)
        corresponding to the ``vis_gridded`` and ``weight_gridded`` quantities
        of the :class:`~mpol.datasets.GriddedDataset` .

        A grid cell is selected when its radius satisfies
        ``q_lo <= q < q_hi`` and its azimuth lies in either ``(phi_lo, phi_hi]``
        or the Hermitian-mirrored ``(phi_lo - pi, phi_hi - pi]``.

        Args:
            cell_index_list (array-like): iterable of ``[q_cell, phi_cell]``
                index pairs to include in the mask, e.g. the output of
                :meth:`get_nonzero_cell_indices`.

        Returns:
            (numpy array) 2D boolean mask in packed format.
        """
        # NOTE(review): np.histogram2d (used by get_polar_histogram) assigns a
        # value lying exactly on an interior edge to the bin on its right,
        # [lo, hi), and closes the last bin on both sides, whereas the tests
        # below use (lo, hi] in phi and [lo, hi) in q. A grid point sitting
        # exactly on an edge can therefore be counted in one cell by the
        # histogram but masked under a neighbouring cell here. Confirm this
        # is intended before relying on exact edge behaviour.
        qs = self.cartesian_qs
        phis = self.cartesian_phis
        mask = np.zeros_like(qs, dtype="bool")

        for q_index, phi_index in cell_index_list:
            q_lo, q_hi = self.q_edges[q_index : q_index + 2]
            phi_lo, phi_hi = self.phi_edges[phi_index : phi_index + 2]
            # the Hermitian partner of this cell, mirrored by -pi in azimuth
            herm_lo, herm_hi = phi_lo - np.pi, phi_hi - np.pi

            in_q = (qs >= q_lo) & (qs < q_hi)
            # the cell's q range and *either* the regular or Hermitian phi range
            in_phi = ((phis > phi_lo) & (phis <= phi_hi)) | (
                (phis > herm_lo) & (phis <= herm_hi)
            )
            mask[in_q & in_phi] = True

        return mask