from typing import Optional
import numpy as np
import torch
from mpol import constants
from mpol.datasets import GriddedDataset
def _chi_squared(
    model_vis: torch.Tensor, data_vis: torch.Tensor, weight: torch.Tensor
) -> torch.Tensor:
    r"""
    Computes the :math:`\chi^2` between the complex data :math:`\boldsymbol{V}`
    and model :math:`M` visibilities using

    .. math::

        \chi^2(\boldsymbol{\theta};\,\boldsymbol{V}) =
        \sum_i^N w_i |V_i - M(u_i, v_i |\,\boldsymbol{\theta})|^2

    where :math:`w_i = 1/\sigma_i^2`. The sum runs over all provided
    visibilities.

    This function is agnostic as to whether the sum should include the
    Hermitian conjugate visibilities, but be aware that the answer returned
    will be different between the two cases. We recommend not including the
    Hermitian conjugates.

    Parameters
    ----------
    model_vis : :class:`torch.Tensor` of :class:`torch.complex`
        array of the model values representing :math:`M`
    data_vis : :class:`torch.Tensor` of :class:`torch.complex`
        array of the data values representing :math:`\boldsymbol{V}`
    weight : :class:`torch.Tensor`
        array of weight values representing :math:`w_i`

    Returns
    -------
    :class:`torch.Tensor`
        the :math:`\chi^2` likelihood, summed over all dimensions of input array.
    """
    # |V - M|^2 folds the real and imaginary residuals into one magnitude
    resid = (data_vis - model_vis).abs()
    return (weight * torch.square(resid)).sum()
def r_chi_squared(
    model_vis: torch.Tensor, data_vis: torch.Tensor, weight: torch.Tensor
) -> torch.Tensor:
    r"""
    Calculate the reduced :math:`\chi^2_\mathrm{R}` between the complex data
    :math:`\boldsymbol{V}` and model :math:`M` visibilities using

    .. math::

        \chi^2_\mathrm{R} = \frac{1}{2 N} \chi^2(\boldsymbol{\theta};\,\boldsymbol{V})

    where :math:`\chi^2` is evaluated using the private function
    :func:`mpol.losses._chi_squared`. Data and model visibilities may be any
    shape as long as all tensors (including weight) share that shape.

    Following `EHT-IV 2019
    <https://ui.adsabs.harvard.edu/abs/2019ApJ...875L...4E/abstract>`_, we
    apply a prefactor :math:`1/(2 N)`, where :math:`N` is the number of
    visibilities; the factor of 2 counts the real and imaginary components in
    the :math:`\chi^2` sum. A well-fit model therefore gives
    :math:`\chi^2_\mathrm{R}(\hat{\boldsymbol{\theta}};\,\boldsymbol{V})
    \approx 1` regardless of the number of data points, which makes it easy to
    set other regularizer strengths *relative* to this value.

    Note that this function should only be used in an optimization or point
    estimate situation `and` where you are not adjusting the weight or the
    amplitudes of the data values. In any situation where uncertainties on
    parameter values are determined (such as Markov Chain Monte Carlo) it will
    return the wrong answer, because the relative scaling of
    :math:`\chi^2_\mathrm{R}` with respect to parameter value is incorrect;
    use :meth:`mpol.losses.log_likelihood` for those applications.

    Parameters
    ----------
    model_vis : :class:`torch.Tensor` of :class:`torch.complex`
        array of the model values representing :math:`M`
    data_vis : :class:`torch.Tensor` of :class:`torch.complex`
        array of the data values representing :math:`\boldsymbol{V}`
    weight : :class:`torch.Tensor`
        array of weight values representing :math:`w_i`

    Returns
    -------
    :class:`torch.Tensor`
        the :math:`\chi^2_\mathrm{R}`, summed over all dimensions of input array.
    """
    # count every complex visibility regardless of the input dimensionality
    n_vis = data_vis.numel()
    return 1 / (2 * n_vis) * _chi_squared(model_vis, data_vis, weight)
def r_chi_squared_gridded(
    modelVisibilityCube: torch.Tensor, griddedDataset: GriddedDataset
) -> torch.Tensor:
    r"""
    Calculate the reduced :math:`\chi^2_\mathrm{R}` between the complex data
    :math:`\boldsymbol{V}` and model :math:`M` visibilities using gridded
    quantities. Returns the same value whether or not Hermitian pairs are
    included.

    Parameters
    ----------
    modelVisibilityCube : :class:`torch.Tensor` of :class:`torch.complex`
        torch tensor with shape ``(nchan, npix, npix)`` to be indexed by the
        ``mask`` from :class:`~mpol.datasets.GriddedDataset`. Assumes tensor is
        "pre-packed," as in output from :meth:`mpol.fourier.FourierCube.forward()`.
    griddedDataset: :class:`~mpol.datasets.GriddedDataset` object
        the gridded dataset, most likely produced from
        :meth:`mpol.gridding.DataAverager.to_pytorch_dataset`

    Returns
    -------
    :class:`torch.Tensor`
        the :math:`\chi^2_\mathrm{R}` value summed over all input dimensions
    """
    # indexing the cube through the dataset yields the model visibilities
    # at the gridded (u,v) locations, matching vis_indexed/weight_indexed
    indexed_model = griddedDataset(modelVisibilityCube)
    indexed_data = griddedDataset.vis_indexed
    indexed_weight = griddedDataset.weight_indexed
    return r_chi_squared(indexed_model, indexed_data, indexed_weight)
def log_likelihood(
    model_vis: torch.Tensor, data_vis: torch.Tensor, weight: torch.Tensor
) -> torch.Tensor:
    r"""
    Compute the log likelihood function :math:`\ln\mathcal{L}` between the
    complex data :math:`\boldsymbol{V}` and model :math:`M` visibilities using

    .. math::

        \ln \mathcal{L}(\boldsymbol{\theta};\,\boldsymbol{V}) =
        - N \ln 2 \pi + \sum_i^N \ln w_i -
        \frac{1}{2} \chi^2(\boldsymbol{\theta};\,\boldsymbol{V})

    where :math:`N` is the number of complex visibilities and :math:`\chi^2`
    is evaluated internally using :func:`mpol.losses._chi_squared`. This
    expression has factors of 2 in different places compared to the familiar
    multivariate Normal because the visibilities are complex-valued; we could
    alternatively write

    .. math::

        \mathcal{L}(\boldsymbol{\theta};\,\boldsymbol{V}) =
        \mathcal{L}(\boldsymbol{\theta};\,\Re\{\boldsymbol{V}\}) \times
        \mathcal{L}(\boldsymbol{\theta};\,\Im\{\boldsymbol{V}\})

    where each factor is the well-known multivariate Normal for reals.

    This function is agnostic as to whether the sum should include the
    Hermitian conjugate visibilities, but the normalization of the answer
    returned will differ between the two cases; inference of the parameter
    values is unaffected. We recommend not including the Hermitian conjugates.

    Parameters
    ----------
    model_vis : :class:`torch.Tensor` of :class:`torch.complex128`
        array of the model values representing :math:`M`
    data_vis : :class:`torch.Tensor` of :class:`torch.complex128`
        array of the data values representing :math:`\boldsymbol{V}`
    weight : :class:`torch.Tensor`
        array of weight values representing :math:`w_i`

    Returns
    -------
    :class:`torch.Tensor`
        the :math:`\ln\mathcal{L}` log likelihood, summed over all dimensions
        of input array.
    """
    # count every complex visibility regardless of the input dimensionality
    n_vis = data_vis.numel()
    # sum of log-weights, kept as an explicitly typed intermediate
    weight_term: torch.Tensor = torch.log(weight).sum()
    # normalization constant; typed as np.float64 so mypy doesn't infer Any
    norm_term: np.float64 = -n_vis * np.log(2 * np.pi)
    return norm_term + weight_term - 0.5 * _chi_squared(model_vis, data_vis, weight)
def log_likelihood_gridded(
    modelVisibilityCube: torch.Tensor, griddedDataset: GriddedDataset
) -> torch.Tensor:
    r"""
    Calculate :math:`\ln\mathcal{L}` (corresponding to
    :func:`~mpol.losses.log_likelihood`) using gridded quantities.

    Parameters
    ----------
    modelVisibilityCube : :class:`torch.Tensor` of :class:`torch.complex`
        torch tensor with shape ``(nchan, npix, npix)`` to be indexed by the
        ``mask`` from :class:`~mpol.datasets.GriddedDataset`. Assumes tensor is
        "pre-packed," as in output from :meth:`mpol.fourier.FourierCube.forward()`.
    griddedDataset: :class:`~mpol.datasets.GriddedDataset` object
        the gridded dataset, most likely produced from
        :meth:`mpol.gridding.DataAverager.to_pytorch_dataset`

    Returns
    -------
    :class:`torch.Tensor`
        the :math:`\ln\mathcal{L}` value, summed over all dimensions of input data.
    """
    # indexing the cube through the dataset collapses it to a 1D tensor of
    # model visibilities, aligned with vis_indexed and weight_indexed
    indexed_model = griddedDataset(modelVisibilityCube)
    return log_likelihood(
        indexed_model, griddedDataset.vis_indexed, griddedDataset.weight_indexed
    )
def neg_log_likelihood_avg(
    model_vis: torch.Tensor, data_vis: torch.Tensor, weight: torch.Tensor
) -> torch.Tensor:
    r"""
    Calculate the average value of the negative log likelihood,

    .. math::

        L = - \frac{1}{2 N} \ln \mathcal{L}(\boldsymbol{\theta};\,\boldsymbol{V})

    where :math:`N` is the number of complex visibilities. This loss function
    is most useful in an optimization or point-estimate situation `and` where
    you may be adjusting the weight or the amplitudes of the data values,
    perhaps via a self-calibration operation. In any situation where
    uncertainties on parameter values are determined (such as Markov Chain
    Monte Carlo), use :meth:`mpol.losses.log_likelihood` instead.

    Parameters
    ----------
    model_vis : :class:`torch.Tensor` of :class:`torch.complex`
        array of the model values representing :math:`M`
    data_vis : :class:`torch.Tensor` of :class:`torch.complex`
        array of the data values representing :math:`\boldsymbol{V}`
    weight : :class:`torch.Tensor`
        array of weight values representing :math:`w_i`

    Returns
    -------
    :class:`torch.Tensor`
        the average of the negative log likelihood, summed over all dimensions
        of input array.
    """
    # number of complex visibilities; the factor of 2 in the denominator
    # accounts for real and imaginary components each counting in ln L
    n_vis = data_vis.numel()
    return -log_likelihood(model_vis, data_vis, weight) / (2 * n_vis)
def entropy(
    cube: torch.Tensor, prior_intensity: torch.Tensor, tot_flux: float = 10
) -> torch.Tensor:
    r"""
    Calculate the entropy loss of a set of pixels following the definition in
    `EHT-IV 2019 <https://ui.adsabs.harvard.edu/abs/2019ApJ...875L...4E/abstract>`_.

    .. math::

        L = \frac{1}{\zeta} \sum_i I_i \; \ln \frac{I_i}{p_i}

    Zero-valued pixels contribute :math:`0` to the sum (the
    :math:`I \ln I` limit), rather than NaN.

    Parameters
    ----------
    cube : :class:`torch.Tensor`
        pixel values must be non-negative :math:`I_i \geq 0` for all :math:`i`
    prior_intensity : :class:`torch.Tensor`
        the prior value :math:`p` to calculate entropy against. Tensors of any
        shape are allowed so long as they will broadcast to the shape of the
        cube under division (`/`).
    tot_flux : float
        a fixed normalization factor; the user-defined target total flux
        density, in units of Jy.

    Returns
    -------
    :class:`torch.Tensor`
        entropy loss
    """
    # validate inputs; elementwise comparisons are reduced with all() so that
    # multi-element tensors (which the docstring permits for prior_intensity)
    # don't raise "Boolean value of Tensor ... is ambiguous"
    assert (cube >= 0.0).all(), "image cube contained negative pixel values"
    assert torch.all(
        torch.as_tensor(prior_intensity) > 0
    ), "image prior intensity must be positive"
    assert tot_flux > 0, "target total flux must be positive"
    # xlogy defines 0 * log(0) = 0, so zero pixels (allowed by the check
    # above) don't poison the sum with NaN
    return (1 / tot_flux) * torch.sum(torch.xlogy(cube, cube / prior_intensity))
def TV_image(sky_cube: torch.Tensor, epsilon: float = 1e-10) -> torch.Tensor:
    r"""
    Calculate the total variation (TV) loss in the image dimension
    (R.A. and DEC), following the definition in `EHT-IV 2019
    <https://ui.adsabs.harvard.edu/abs/2019ApJ...875L...4E/abstract>`_.
    Promotes the image to be piecewise smooth and the gradient of the image
    to be sparse.

    .. math::

        L = \sum_{l,m,v} \sqrt{(I_{l + 1, m, v} - I_{l,m,v})^2 +
            (I_{l, m+1, v} - I_{l, m, v})^2 + \epsilon}

    Parameters
    ----------
    sky_cube: 3D :class:`torch.Tensor`
        the image cube array :math:`I_{lmv}`, where :math:`l`
        is R.A. in :math:`ndim=3`, :math:`m` is DEC in :math:`ndim=2`, and
        :math:`v` is the channel (velocity or frequency) dimension in
        :math:`ndim=1`. Should be in sky format representation.
    epsilon : float
        a softening parameter in units of [:math:`\mathrm{Jy}/\mathrm{arcsec}^2`].
        Any pixel-to-pixel variations within each image North-South or
        East-West slice greater than this parameter will incur a significant
        penalty.

    Returns
    -------
    :class:`torch.Tensor`
        total variation loss
    """
    # finite differences along l (dim=2) and m (dim=1); each input is cropped
    # by one pixel in the *other* axis so the two difference maps share a shape
    d_l = torch.diff(sky_cube[:, :-1, :], dim=2)
    d_m = torch.diff(sky_cube[:, :, :-1], dim=1)
    return torch.sum(torch.sqrt(d_l**2 + d_m**2 + epsilon))
def TV_channel(cube: torch.Tensor, epsilon: float = 1e-10) -> torch.Tensor:
    r"""
    Calculate the total variation (TV) loss in the channel (first) dimension,
    following the definition in `EHT-IV 2019
    <https://ui.adsabs.harvard.edu/abs/2019ApJ...875L...4E/abstract>`_:

    .. math::

        L = \sum_{l,m,v} \sqrt{(I_{l, m, v + 1} - I_{l,m,v})^2 + \epsilon}

    Parameters
    ----------
    cube: :class:`torch.Tensor`
        the image cube array :math:`I_{lmv}`
    epsilon: float
        a softening parameter in units of [:math:`\mathrm{Jy}/\mathrm{arcsec}^2`].
        Any channel-to-channel pixel variations greater than this parameter
        will incur a significant penalty.

    Returns
    -------
    :class:`torch.Tensor`
        total variation loss
    """
    # channel-to-channel finite difference (cube[1:] - cube[:-1]) along dim 0
    d_v = torch.diff(cube, dim=0)
    return torch.sqrt(d_v**2 + epsilon).sum()
def TSV(sky_cube: torch.Tensor) -> torch.Tensor:
    r"""
    Calculate the total square variation (TSV) loss in the image dimension
    (R.A. and DEC), following the definition in `EHT-IV 2019
    <https://ui.adsabs.harvard.edu/abs/2019ApJ...875L...4E/abstract>`_.
    Promotes the image to be edge smoothed, which may be a better
    representation of the truth image (`K. Kuramochi et al 2018
    <https://ui.adsabs.harvard.edu/abs/2018ApJ...858...56K/abstract>`_).

    .. math::

        L = \sum_{l,m,v} (I_{l + 1, m, v} - I_{l,m,v})^2 +
            (I_{l, m+1, v} - I_{l, m, v})^2

    Parameters
    ----------
    sky_cube :class:`torch.Tensor`
        the image cube array :math:`I_{lmv}`, where :math:`l`
        is R.A. in :math:`ndim=3`, :math:`m` is DEC in :math:`ndim=2`, and
        :math:`v` is the channel (velocity or frequency) dimension in
        :math:`ndim=1`. Should be in sky format representation.

    Returns
    -------
    :class:`torch.Tensor`
        total square variation loss
    """
    # finite differences along l (dim=2) and m (dim=1); each input is cropped
    # by one pixel in the *other* axis so the two difference maps share a shape
    d_l = torch.diff(sky_cube[:, :-1, :], dim=2)
    d_m = torch.diff(sky_cube[:, :, :-1], dim=1)
    return (d_l**2 + d_m**2).sum()
def sparsity(cube: torch.Tensor, mask: Optional[torch.Tensor] = None) -> torch.Tensor:
    r"""
    Enforce a sparsity prior on the image cube using the :math:`L_1` norm.
    Optionally provide a boolean mask to apply the prior to only the ``True``
    locations; for example, you might want this mask to be ``True`` for
    background regions.

    The sparsity loss is calculated as

    .. math::

        L = \sum_i | I_i |

    Parameters
    ----------
    cube : :class:`torch.Tensor`
        the image cube array :math:`I_{lmv}`
    mask : :class:`torch.Tensor` of :class:`torch.bool`
        tensor array the same shape as ``cube``. The sparsity prior
        will be applied to those pixels where the mask is ``True``. Default is
        to apply prior to all pixels.

    Returns
    -------
    :class:`torch.Tensor`
        sparsity loss calculated where ``mask == True``
    """
    # restrict to the masked pixels when a mask is supplied
    selected = cube if mask is None else cube.masked_select(mask)
    return selected.abs().sum()
def UV_sparsity(
    vis: torch.Tensor, qs: torch.Tensor, q_max: torch.Tensor
) -> torch.Tensor:
    r"""
    Enforce a sparsity prior for all :math:`q = \sqrt{u^2 + v^2}` points
    larger than :math:`q_\mathrm{max}`.

    Parameters
    ----------
    vis : :class:`torch.Tensor` of :class:`torch.complex128`
        visibility cube of (nchan, npix, npix//2 +1, 2)
    qs : :class:`torch.Tensor` of :class:`torch.float64`
        array corresponding to visibility coordinates. Dimensionality of
        (npix, npix//2)
    q_max : float or scalar :class:`torch.Tensor`
        maximum radial baseline

    Returns
    -------
    :class:`torch.Tensor`
        UV sparsity loss above :math:`q_\mathrm{max}`
    """
    # the comparison already yields a bool tensor; move it to the visibility
    # device (GPU-safe) rather than re-wrapping with torch.tensor(), which
    # copies the data and triggers a UserWarning
    mask = (qs > q_max).to(vis.device)
    # last axis packs (real, imaginary) components
    vis_re = vis[..., 0]
    vis_im = vis[..., 1]
    # add a leading channel axis so the mask broadcasts across nchan
    mask = mask.unsqueeze(0)
    loss = torch.sum(torch.abs(vis_re.masked_select(mask))) + torch.sum(
        torch.abs(vis_im.masked_select(mask))
    )
    return loss
def PSD(qs: torch.Tensor, psd: torch.Tensor, l: torch.Tensor) -> torch.Tensor:
    r"""
    Apply a loss function corresponding to the power spectral density using a
    Gaussian process kernel.

    Assumes an image plane kernel of

    .. math::

        k(r) = \exp(-\frac{r^2}{2 \ell^2})

    The corresponding power spectral density is

    .. math::

        P(q) = (2 \pi \ell^2) \exp(- 2 \pi^2 \ell^2 q^2)

    Parameters
    ----------
    qs : :class:`torch.Tensor`
        the radial UV coordinate (in :math:`\lambda`)
    psd : :class:`torch.Tensor`
        the power spectral density cube
    l : :class:`torch.Tensor`
        the correlation length in the image plane (in arcsec)

    Returns
    -------
    :class:`torch.Tensor`
        the loss calculated using the power spectral density
    """
    # convert the correlation length from arcsec to radians
    ell = l * constants.arcsec
    # analytic PSD of the Gaussian kernel evaluated at each q
    expected_psd = 2 * np.pi * ell**2 * torch.exp(-2 * np.pi**2 * ell**2 * qs**2)
    # chi^2-like penalty; the division broadcasts across all channels
    return torch.sum(psd / expected_psd)
def edge_clamp(cube: torch.Tensor) -> torch.Tensor:
    r"""
    Promote all pixels at the edge of the image to be zero using an
    :math:`L_2` norm.

    Parameters
    ----------
    cube: :class:`torch.Tensor`
        the image cube array :math:`I_{lmv}`

    Returns
    -------
    :class:`torch.Tensor`
        edge loss
    """
    # top/bottom rows and left/right columns across all channels; the four
    # corner pixels appear in both selections and are therefore counted twice
    top_bottom = cube[:, (0, -1)]
    left_right = cube[:, :, (0, -1)]
    return (top_bottom**2).sum() + (left_right**2).sum()