%run notebook_setup

Intro to RML with MPoL

In this tutorial, we’ll construct an optimization loop demonstrating how we can use MPoL to synthesize a basic image. We’ll continue with the dataset described in the Gridding and Diagnostic Images tutorial.

Gridding recap

Let’s set up the DataAverager and GridCoords objects as before:

import matplotlib.pyplot as plt
import numpy as np
import torch
from astropy.utils.data import download_file
from IPython.display import SVG, display
from mpol import coordinates, fourier, gridding, losses, precomposed, utils
from mpol.__init__ import zenodo_record
# load the mock dataset of the ALMA logo
fname = download_file(
    f"https://zenodo.org/record/{zenodo_record}/files/logo_cube.noise.npz",
    cache=True,
    show_progress=True,
    pkgname="mpol",
)

# this is a multi-channel dataset... for demonstration purposes we'll use
# only the central, single channel
chan = 4
d = np.load(fname)
uu = d["uu"][chan]
vv = d["vv"][chan]
weight = d["weight"][chan]
data = d["data"][chan]
data_re = np.real(data)
data_im = np.imag(data)
# define the image dimensions, as in the previous tutorial
coords = coordinates.GridCoords(cell_size=0.005, npix=800)
averager = gridding.DataAverager(
    coords=coords,
    uu=uu,
    vv=vv,
    weight=weight,
    data_re=data_re,
    data_im=data_im,
)

The PyTorch dataset

Now we will export the visibilities to a PyTorch dataset to use in the imaging loop. The mpol.gridding.DataAverager.to_pytorch_dataset() routine computes a weighted average of all the visibilities that fall within each Fourier grid cell and exports the averaged visibilities to cube-like PyTorch tensors. To keep things simple in this tutorial, we are only using a single channel, but you could just as easily export a multi-channel dataset. Note that the to_pytorch_dataset() routine automatically checks the visibility scatter and raises a RuntimeError if the empirically-estimated scatter exceeds that expected from the provided dataset weights. For more information, see the end of the Gridding and Diagnostic Images tutorial.
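To make the averaging step concrete, here is a simplified NumPy sketch of a weighted average of visibilities that share a grid cell. It assumes that each visibility has already been assigned a flat cell index (the hypothetical `cell_index` array). It is not MPoL's implementation, which also handles the mapping from (u,v) to cells and other bookkeeping. Because each weight is an inverse variance, the weight of an averaged cell is the sum of the contributing weights.

```python
import numpy as np


def grid_average(cell_index, weight, data, ncells):
    """Weighted-average complex visibilities that fall into the same grid cell.

    cell_index : int array, flat index of the cell each visibility lands in (hypothetical input)
    weight     : float array, thermal weights (1 / sigma^2)
    data       : complex array, visibility values
    ncells     : total number of grid cells
    Returns (mean visibility per cell, summed weight per cell); empty cells get 0.
    """
    wsum = np.bincount(cell_index, weights=weight, minlength=ncells)
    re = np.bincount(cell_index, weights=weight * data.real, minlength=ncells)
    im = np.bincount(cell_index, weights=weight * data.imag, minlength=ncells)
    mean = np.zeros(ncells, dtype=complex)
    filled = wsum > 0
    mean[filled] = (re[filled] + 1j * im[filled]) / wsum[filled]
    return mean, wsum


# two visibilities share cell 0, one lands in cell 2, cells 1 and 3 stay empty
mean, wsum = grid_average(
    np.array([0, 0, 2]), np.array([1.0, 3.0, 2.0]), np.array([1 + 1j, 2 + 0j, 5 - 1j]), 4
)
```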

In the following tutorial on the NuFFT, we’ll explore an alternate MPoL layer that avoids gridding the visibilities altogether. This approach may be more accurate for certain applications, but it is usually slower to execute than the gridding approach described in this tutorial. For that reason, we recommend starting with the default gridding approach and only moving to the NuFFT layers once you are reasonably happy with the images you are getting.

dset = averager.to_pytorch_dataset()
print("this dataset has {:} channel".format(dset.nchan))
this dataset has 1 channel

Building an image model

MPoL provides “modules” to build and optimize complex imaging workflows, not unlike how a deep neural network might be constructed. We’ve bundled the most common modules for imaging together in an mpol.precomposed.SimpleNet meta-module, which we’ll use here.

This diagram shows how the primitive modules, like mpol.images.BaseCube, mpol.images.ImageCube, etc… are connected together to form mpol.precomposed.SimpleNet. In this workflow, the pixel values of the mpol.images.BaseCube are the core model parameters representing the image. More information about all of these components is available in the API documentation.

graph TD
    subgraph SimpleNet
        bc(BaseCube) --> HannConvCube
        HannConvCube --> ImageCube
        ImageCube --> FourierLayer
    end
    FourierLayer --> il([Loss])
    ad[[Dataset]] --> il([Loss])

It isn’t necessary to construct a meta-module to do RML imaging with MPoL, though it often helps organize your code. If we so desired, we could connect the individual modules together ourselves, following the SimpleNet source code (mpol.precomposed.SimpleNet) as an example, and swap modules in or out as we saw fit.

We then initialize SimpleNet with the image coordinates and the number of channels:

rml = precomposed.SimpleNet(coords=coords, nchan=dset.nchan)

Breaking down the training loop

Our goal for the rest of the tutorial is to set up a loop that will

  1. evaluate the current model against a loss function

  2. calculate the gradients of the loss w.r.t. the model

  3. advance the model parameters in the direction that decreases the loss function
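The same three steps can be seen in miniature on a toy problem. This NumPy sketch is purely illustrative and unrelated to MPoL's API. The "model" is a parameter vector `x`, the loss is 0.5 ||A x - b||², and the gradient is written out analytically, whereas PyTorch would obtain it with `loss.backward()`.

```python
import numpy as np


def fit_least_squares(A, b, lr=0.1, n_iter=500):
    """Toy gradient-descent loop for the loss 0.5 * ||A x - b||^2."""
    x = np.zeros(A.shape[1])
    losses = []
    for _ in range(n_iter):
        resid = A @ x - b
        losses.append(0.5 * resid @ resid)  # 1. evaluate the loss
        grad = A.T @ resid                  # 2. gradient of the loss w.r.t. x
        x = x - lr * grad                   # 3. step against the gradient
    return x, losses


A = np.array([[1.0, 0.0], [0.0, 2.0]])
b = np.array([1.0, 4.0])
x_fit, loss_history = fit_least_squares(A, b)  # exact solution is x = [1, 2]
```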

We’ll start by creating the optimizer:

optimizer = torch.optim.SGD(rml.parameters(), lr=3e4)

The role of the optimizer is to advance the parameters (in this case, the pixel values of the mpol.images.BaseCube) using the gradient of the loss function with respect to those parameters. PyTorch has many different optimizers available, and it is worthwhile to try out some of the different ones. Stochastic Gradient Descent (SGD) is one of the simplest, so we’ll start here. The lr parameter is the ‘learning rate,’ or how ambitious the optimizer should be in taking descent steps. Tuning this requires a bit of trial and error: you want the learning rate to be small enough that the algorithm doesn’t diverge but large enough that the optimization completes in a reasonable amount of time.
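The trade-off in choosing a learning rate can be seen on a one-dimensional quadratic loss L(x) = 0.5 a x². Each gradient-descent step multiplies x by (1 - lr·a), so the iterates shrink only when 0 < lr < 2/a. This is a toy illustration, not MPoL code. The stable range depends on the curvature of the loss, which is one reason a good `lr` is problem-specific.

```python
def descend_quadratic(curvature, lr, x0=1.0, n_iter=50):
    """Run gradient descent on L(x) = 0.5 * curvature * x**2 and return |x|."""
    x = x0
    for _ in range(n_iter):
        grad = curvature * x  # dL/dx
        x = x - lr * grad     # each step scales x by (1 - lr * curvature)
    return abs(x)


# with curvature 4, the stable range is 0 < lr < 0.5
fast = descend_quadratic(4.0, 0.1)   # converges quickly
slow = descend_quadratic(4.0, 0.01)  # converges, but slowly
bad = descend_quadratic(4.0, 0.6)    # diverges
```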

Loss functions

In the parlance of the machine learning community, one defines “loss” functions that compare models to data. For regularized maximum likelihood imaging, the most fundamental loss function we’ll use is the data likelihood. For gridded data this is mpol.losses.nll_gridded(), a normalized negative log likelihood closely related to the \(\chi^2\) value comparing the model visibilities to the data visibilities. For this introductory tutorial, we’ll use only the data likelihood loss function. Because imaging is an ill-posed inverse problem, however, this is not a sufficient constraint by itself. In later tutorials, we will apply regularization to narrow the set of possible images towards ones that we believe are more realistic. The mpol.losses module contains several loss functions currently popular in the literature, so you can experiment to see which best suits your application.
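As a reference point, this NumPy sketch computes the raw \(\chi^2\) for complex visibilities. It treats the real and imaginary parts as independent measurements, each with standard deviation σ = 1/√w. It does not reproduce the exact normalization of mpol.losses.nll_gridded(); see the mpol.losses API documentation for that. When the model equals the truth, each of the N visibilities contributes two real measurements, so χ²/(2N) should be close to 1.

```python
import numpy as np


def chi_squared(data, model, weight):
    """chi^2 for complex visibilities with thermal weights w = 1 / sigma^2.

    Real and imaginary parts each have standard deviation sigma, so every
    visibility contributes w * |data - model|^2.
    """
    return np.sum(weight * np.abs(data - model) ** 2)


# synthetic check: data = truth + Gaussian noise with per-visibility sigma
rng = np.random.default_rng(0)
N = 200_000
sigma = rng.uniform(0.5, 2.0, N)
model = rng.normal(size=N) + 1j * rng.normal(size=N)
data = model + sigma * (rng.normal(size=N) + 1j * rng.normal(size=N))
reduced = chi_squared(data, model, 1 / sigma**2) / (2 * N)  # expected to be close to 1
```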

Gradient descent

Let’s walk through how we calculate a loss value and optimize the parameters. To start, let’s examine the parameters of the model:

rml.state_dict()
OrderedDict([('bcube.base_cube',
              tensor([[[0.0500, 0.0500, 0.0500,  ..., 0.0500, 0.0500, 0.0500],
                       [0.0500, 0.0500, 0.0500,  ..., 0.0500, 0.0500, 0.0500],
                       [0.0500, 0.0500, 0.0500,  ..., 0.0500, 0.0500, 0.0500],
                       ...,
                       [0.0500, 0.0500, 0.0500,  ..., 0.0500, 0.0500, 0.0500],
                       [0.0500, 0.0500, 0.0500,  ..., 0.0500, 0.0500, 0.0500],
                       [0.0500, 0.0500, 0.0500,  ..., 0.0500, 0.0500, 0.0500]]],
                     dtype=torch.float64)),
             ('conv_layer.m.weight',
              tensor([[[[0.0625, 0.1250, 0.0625],
                        [0.1250, 0.2500, 0.1250],
                        [0.0625, 0.1250, 0.0625]]]], dtype=torch.float64)),
             ('conv_layer.m.bias', tensor([0.], dtype=torch.float64))])

These are the default values that were used to initialize the mpol.images.BaseCube component of the mpol.precomposed.SimpleNet.

For demonstration purposes, let’s access and plot the base cube with matplotlib. In a normal workflow you probably won’t need to do this, but to access the base cube in sky orientation, we use utils.packed_cube_to_sky_cube:

bcube_pytorch = utils.packed_cube_to_sky_cube(rml.bcube.base_cube)

bcube_pytorch is still a PyTorch tensor, but matplotlib requires numpy arrays. To convert it, we first need to “detach” the tensor from the computational graph (used to propagate gradients) and then call the numpy conversion routine.

bcube_numpy = bcube_pytorch.detach().numpy()
print(bcube_numpy.shape)
(1, 800, 800)

Lastly, we use np.squeeze to remove the channel dimension and plot the 2D image:

fig, ax = plt.subplots(nrows=1)
im = ax.imshow(
    np.squeeze(bcube_numpy),
    origin="lower",
    interpolation="none",
    extent=rml.icube.coords.img_ext,
)
plt.xlabel(r"$\Delta \alpha \cos \delta$ [${}^{\prime\prime}$]")
plt.ylabel(r"$\Delta \delta$ [${}^{\prime\prime}$]")
plt.colorbar(im)

A blank image is not that exciting, but hopefully this demonstrates the state of the parameters at the start of optimization.

PyTorch accumulates gradients by default. Because we’ll want to compute a clean set of gradient values in a later step, we first “zero out” any gradients already attached to the parameters, so that gradients from earlier backward passes aren’t added to the new ones.

rml.zero_grad()

Most modules in MPoL are designed to work in a “feed forward” manner, which means base parameters are processed through the network to predict model visibilities for comparison with data. We can calculate the full visibility cube corresponding to the current pixel values of the mpol.images.BaseCube:

vis = rml()
print(vis)
tensor([[[ 1.1481e+01+0.0000e+00j,  7.1800e-03+2.8196e-05j,
          -7.1797e-03-5.6390e-05j,  ...,
           7.1791e-03-8.4581e-05j, -7.1797e-03+5.6390e-05j,
           7.1800e-03-2.8196e-05j],
         [ 7.1800e-03+2.8196e-05j,  4.4902e-06+3.5266e-08j,
          -4.4899e-06-5.2898e-08j,  ...,
           4.4899e-06-3.5264e-08j, -4.4902e-06+1.7633e-08j,
           4.4903e-06+1.9516e-22j],
         [-7.1797e-03-5.6390e-05j, -4.4899e-06-5.2898e-08j,
           4.4895e-06+7.0527e-08j,  ...,
          -4.4899e-06+1.7632e-08j,  4.4901e-06+6.9389e-22j,
          -4.4902e-06-1.7633e-08j],
         ...,
         [ 7.1791e-03-8.4581e-05j,  4.4899e-06-3.5264e-08j,
          -4.4899e-06+1.7632e-08j,  ...,
           4.4885e-06-1.0578e-07j, -4.4891e-06+8.8154e-08j,
           4.4895e-06-7.0526e-08j],
         [-7.1797e-03+5.6390e-05j, -4.4902e-06+1.7633e-08j,
           4.4901e-06-8.6736e-23j,  ...,
          -4.4891e-06+8.8154e-08j,  4.4895e-06-7.0527e-08j,
          -4.4899e-06+5.2898e-08j],
         [ 7.1800e-03-2.8196e-05j,  4.4903e-06-1.7347e-22j,
          -4.4902e-06-1.7633e-08j,  ...,
           4.4895e-06-7.0526e-08j, -4.4899e-06+5.2898e-08j,
           4.4902e-06-3.5266e-08j]]], dtype=torch.complex128,
       grad_fn=<MulBackward0>)

Of course, these aren’t that exciting since they just reflect the constant value image.
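To see roughly what happens inside SimpleNet's forward pass, here is a NumPy sketch of the chain in the diagram: base parameters, then a positive image, then a smoothed image, then visibilities. It rests on several assumptions. BaseCube is taken to map parameters to pixels with a softplus function, which is our reading of how positivity is enforced. The 3×3 smoothing uses zero padding. MPoL's packed-cube ordering and phase conventions are ignored. It is not MPoL's code.

```python
import numpy as np

# 3x3 smoothing kernel, copied from the conv_layer.m.weight values printed above
KERNEL = np.array([[0.0625, 0.125, 0.0625],
                   [0.125, 0.25, 0.125],
                   [0.0625, 0.125, 0.0625]])


def softplus(x):
    # smooth, strictly positive map from an unconstrained parameter to a pixel value
    return np.log1p(np.exp(x))


def smooth(img):
    # "same"-size 3x3 convolution with zero padding at the image edges;
    # the kernel is symmetric, so correlation and convolution coincide
    padded = np.pad(img, 1)
    out = np.zeros_like(img)
    ny, nx = img.shape
    for dy in range(3):
        for dx in range(3):
            out += KERNEL[dy, dx] * padded[dy:dy + ny, dx:dx + nx]
    return out


def forward(base, cell_size):
    """Toy forward model: base parameters -> positive image -> smoothed image -> visibilities."""
    img = smooth(softplus(base))             # image in Jy/arcsec^2
    vis = np.fft.fft2(img) * cell_size ** 2  # multiply by the pixel area (arcsec^2) to get Jy
    return img, vis


img, vis = forward(np.full((800, 800), 0.05), cell_size=0.005)
```

As a consistency check, with an 800 × 800 grid of 0.05 and `cell_size=0.005`, the zero-spacing value `vis[0, 0]` of this sketch comes out at about 11.481. That matches the first element of the rml() output printed above, which suggests these assumptions are at least consistent with SimpleNet's defaults.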

But exciting things are about to happen! We can now calculate the loss between the model visibilities from rml() and the data:

# calculate a loss
loss = losses.nll_gridded(vis, dset)
print(loss.item())
2.5873771357650925

and then we can calculate the gradient of the loss function with respect to the parameters

loss.backward()

We can even visualize what the gradient of the mpol.images.BaseCube looks like (using a similar .detach() call as before)

fig, ax = plt.subplots(nrows=1)
im = ax.imshow(
    np.squeeze(
        utils.packed_cube_to_sky_cube(rml.bcube.base_cube.grad).detach().numpy()
    ),
    origin="lower",
    interpolation="none",
    extent=rml.icube.coords.img_ext,
)
plt.xlabel(r"$\Delta \alpha \cos \delta$ [${}^{\prime\prime}$]")
plt.ylabel(r"$\Delta \delta$ [${}^{\prime\prime}$]")
plt.colorbar(im)

The gradient image points in the direction of increasing loss. So the final step is to subtract the gradient image, scaled by the learning rate, from the base image, which moves the base parameters toward lower loss values. This process is called gradient descent, and it can be extremely useful for optimizing high-dimensional parameter spaces (like images). The optimizer carries out this update:

optimizer.step()
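Written out, torch.optim.SGD with its default settings (no momentum and no weight decay) applies the plain gradient-descent step. Here \(\boldsymbol{\theta}\) stands for the trainable parameters (the base-cube pixel values), \(\eta\) is the learning rate lr, and \(L\) is the loss:

\[ \boldsymbol{\theta} \leftarrow \boldsymbol{\theta} - \eta \, \nabla_{\boldsymbol{\theta}} L(\boldsymbol{\theta}) \]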

We can see that the parameter values have changed:

rml.state_dict()
OrderedDict([('bcube.base_cube',
              tensor([[[0.3193, 0.3336, 0.3468,  ..., 0.2729, 0.2887, 0.3042],
                       [0.2943, 0.3086, 0.3217,  ..., 0.2477, 0.2635, 0.2791],
                       [0.2693, 0.2836, 0.2967,  ..., 0.2226, 0.2385, 0.2541],
                       ...,
                       [0.3939, 0.4080, 0.4212,  ..., 0.3484, 0.3639, 0.3791],
                       [0.3693, 0.3835, 0.3967,  ..., 0.3235, 0.3390, 0.3544],
                       [0.3444, 0.3587, 0.3718,  ..., 0.2982, 0.3139, 0.3294]]],
                     dtype=torch.float64)),
             ('conv_layer.m.weight',
              tensor([[[[0.0625, 0.1250, 0.0625],
                        [0.1250, 0.2500, 0.1250],
                        [0.0625, 0.1250, 0.0625]]]], dtype=torch.float64)),
             ('conv_layer.m.bias', tensor([0.], dtype=torch.float64))])

as has the base image

fig, ax = plt.subplots(nrows=1)
im = ax.imshow(
    np.squeeze(utils.packed_cube_to_sky_cube(rml.bcube.base_cube).detach().numpy()),
    origin="lower",
    interpolation="none",
    extent=rml.icube.coords.img_ext,
)
plt.xlabel(r"$\Delta \alpha \cos \delta$ [${}^{\prime\prime}$]")
plt.ylabel(r"$\Delta \delta$ [${}^{\prime\prime}$]")
plt.colorbar(im)

Iterating the training loop

Now that we’ve covered how to use gradient descent to optimize a set of image parameters, let’s wrap these steps into a training loop and iterate a few hundred times to converge to a final product.

In addition to the steps just outlined, we’ll also track the loss values as we optimize.

%%time

loss_tracker = []

for i in range(300):
    rml.zero_grad()

    # get the predicted model
    vis = rml()

    # calculate a loss
    loss = losses.nll_gridded(vis, dset)

    loss_tracker.append(loss.item())

    # calculate gradients of parameters
    loss.backward()

    # update the model parameters
    optimizer.step()
CPU times: user 17.8 s, sys: 2.62 s, total: 20.4 s
Wall time: 10.7 s
fig, ax = plt.subplots(nrows=1)
ax.plot(loss_tracker)
ax.set_xlabel("iteration")
ax.set_ylabel("loss")

and we see that we’ve reasonably converged to a set of parameters without much further improvement in the loss value.
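Rather than judging convergence by eye from the plot, one could stop the loop automatically once the loss stops improving. Below is a minimal sketch of such a check. `has_converged` is a hypothetical helper, not part of MPoL, and its `window` and `rtol` defaults are illustrative rather than recommended values.

```python
import numpy as np


def has_converged(loss_history, window=20, rtol=1e-4):
    """Return True if the loss improved by less than rtol (relative) over the last `window` steps."""
    if len(loss_history) <= window:
        return False
    old, new = loss_history[-window - 1], loss_history[-1]
    return (old - new) <= rtol * abs(old)


# a loss curve that decays towards a plateau at 1.0
curve = list(1 + np.exp(-np.arange(300) / 10))
early = has_converged(curve[:30])  # still improving
late = has_converged(curve)        # flat
```

Inside the training loop, this could be called right after `loss_tracker.append(loss.item())`, breaking out of the loop when it returns True.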

All of the methods presented here can be sped up using GPU acceleration on certain NVIDIA GPUs. To learn more about this, please see the GPU Setup Tutorial.

Visualizing the image

Let’s visualize the final image product. The bounds for matplotlib.pyplot.imshow are available in the img_ext attribute of the image coordinates (rml.icube.coords.img_ext).

# let's see what one channel of the image looks like
fig, ax = plt.subplots(nrows=1)
img_cube = rml.icube.sky_cube.detach().numpy()
im = ax.imshow(
    np.squeeze(img_cube),
    origin="lower",
    interpolation="none",
    extent=rml.icube.coords.img_ext,
)
plt.xlabel(r"$\Delta \alpha \cos \delta$ [${}^{\prime\prime}$]")
plt.ylabel(r"$\Delta \delta$ [${}^{\prime\prime}$]")
plt.colorbar(im, label=r"Jy/$\mathrm{arcsec}^2$")

And there you have it, an image optimized to fit the data. To be honest, the results aren’t great. That’s because the only regularization we’ve applied is implicit in the functional basis set we chose, which automatically enforces image positivity (see the mpol.images.BaseCube documentation). Otherwise, our only contribution to the loss function is the data likelihood. This means it’s easy for the lower signal-to-noise visibilities at longer baselines to dominate the image appearance (not unlike CLEAN images made with “uniform” weighting), and there is high “noise” in the image.

Visualizing the residuals

We started this tutorial with a collection of visibility data: the complex-valued samples of the Fourier transform of the true sky brightness. We used the forward-modeling capabilities of MPoL and PyTorch to propose image-plane models of the true sky brightness and then used the Fast Fourier Transform to convert these into model visibilities. These model visibilities were compared against the data visibilities using the negative log likelihood loss function and (optionally) additional losses calculated using regularizers. We then used the PyTorch autodifferentiation machinery to calculate the derivatives of this loss function with respect to the image parameters, and we evolved the model of the sky brightness until the loss function was minimized. At the end of this process, we are left with an image-plane model that minimizes the loss function when compared against the data visibilities and any regularizers.

It is nearly always worthwhile to visualize the residuals, defined as

\[ \mathrm{residuals} = \mathrm{data} - \mathrm{model} \]

For speed reasons, mpol.precomposed.SimpleNet does not work with the original data visibilities directly. Instead, it uses the averaged version of them stored in the GriddedDataset (dset). To calculate model visibilities at the original \(u,v\) points of the dataset, we will need to use the mpol.fourier.NuFFT layer. More detail on this object is in the Loose Visibilities tutorial. In brief, we instantiate the NuFFT layer with the image coordinates and the number of channels, and then supply the \(u,v\) locations when we call it.

nufft = fourier.NuFFT(coords=coords, nchan=dset.nchan)

and then we can calculate model visibilities corresponding to some model image (in this case, our optimal image). Since mpol.fourier.NuFFT.forward() returns a PyTorch tensor, we’ll need to detach it and convert it to numpy. We’ll also remove the channel dimension.

# note the NuFFT expects a "packed image cube," as output from ImageCube.forward()
vis_model = nufft(rml.icube(), uu, vv)
# convert from Pytorch to numpy, remove channel dimension
vis_model = np.squeeze(vis_model.detach().numpy())

and then use these model visibilities to calculate residual visibilities

vis_resid = data - vis_model
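For intuition about what the NuFFT layer used above computes, the quantity it approximates can be written as a direct Fourier sum over pixels, V(u,v) = Σ I(l,m) exp(-2πi(ul + vm)) Δl Δm. The sketch below costs O(N_pix × N_vis), so it is only practical for tiny images or spot checks. It assumes u,v in wavelengths and sky offsets in arcseconds (MPoL itself uses kilolambda), and it does not model MPoL's exact sign or image-orientation conventions.

```python
import numpy as np

ARCSEC = np.pi / (180 * 3600)  # one arcsecond in radians


def direct_ft(img, l, m, u, v, cell_size):
    """Direct (slow) Fourier sum of an image at arbitrary (u, v) points.

    img       : 2D image in Jy/arcsec^2, with img[j, i] at sky offsets (l[i], m[j]) in arcsec
    u, v      : baseline coordinates in wavelengths (assumption of this sketch)
    cell_size : pixel size in arcsec, so img * cell_size**2 is Jy per pixel
    """
    L, M = np.meshgrid(l * ARCSEC, m * ARCSEC)  # pixel positions in radians
    flux = (img * cell_size ** 2).ravel()       # Jy in each pixel
    phase = -2j * np.pi * (np.outer(u, L.ravel()) + np.outer(v, M.ravel()))
    return np.exp(phase) @ flux


# 1 Jy point source at the phase center: visibilities are 1 at every (u, v)
l = m = np.arange(-2, 3) * 0.1
cell = 0.1
point = np.zeros((5, 5))
point[2, 2] = 1 / cell**2
u = np.array([0.0, 1e5, 3e5])
v = np.array([0.0, 2e5, -1e5])
vis_point = direct_ft(point, l, m, u, v, cell)
```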

There are many ways we could visualize these residuals. The simplest type of visualization is to examine the scatter of the residuals relative to their expected standard deviation (given by their thermal weights).

sigmas = np.sqrt(1/weight)
resid_real = np.real(vis_resid) / sigmas
resid_imag = np.imag(vis_resid) / sigmas


def gaussian(x, sigma=1):
    r"""
    Evaluate a reference Gaussian as a function of :math:`x`

    Args:
        x (float): location to evaluate Gaussian
        sigma (float): standard deviation of the Gaussian (default 1)

    The Gaussian is defined as

    .. math::

        f(x) = \frac{1}{\sigma \sqrt{2 \pi}} \exp \left ( -\frac{x^2}{2 \sigma^2}\right )

    Returns:
        Gaussian function evaluated at :math:`x`
    """
    return 1 / (sigma * np.sqrt(2 * np.pi)) * np.exp(-0.5 * (x / sigma) ** 2)


fig, ax = plt.subplots(ncols=2, figsize=(5, 2.5))

xs = np.linspace(-5, 5)

ax[0].hist(resid_real, density=True, bins=40)
ax[1].hist(resid_imag, density=True, bins=40)

ax[0].set_xlabel(
    r"$\mathrm{Re} \{ V_\mathrm{data} - V_\mathrm{model} \} / \sigma$"
)

ax[1].set_xlabel(
    r"$\mathrm{Im} \{ V_\mathrm{data} - V_\mathrm{model} \} / \sigma$"
)

for a in ax.flatten():
    a.plot(xs, gaussian(xs))

fig.subplots_adjust(wspace=0.3)

If our model is good, we would expect it to fit the signal contained in the data, leaving residuals that contain only noise. As far as this plot is concerned, it appears as though we might have achieved this. However, most radio interferometric visibilities are noise-dominated (even for high signal-to-noise sources), so a balanced residual scatter is easier to achieve than you might think. If the residuals were significantly over- or under-dispersed relative to their theoretical standard deviation, though, we should examine the calibration of the data (in particular, the calculation of the visibility weights).
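The histogram comparison can also be summarized by a single number: the standard deviation of the weight-normalized residuals, which should be close to 1. The sketch below would operate on the resid_real and resid_imag arrays computed above. How far from 1 counts as "significantly" over- or under-dispersed is a judgment call, not an MPoL recommendation.

```python
import numpy as np


def scatter_ratio(resid_real, resid_imag):
    """Standard deviation of weight-normalized residuals, real and imaginary parts pooled.

    Near 1: scatter consistent with the weights. Well above 1: noise under-estimated
    (or unmodeled signal). Well below 1: noise over-estimated.
    """
    pooled = np.concatenate([np.ravel(resid_real), np.ravel(resid_imag)])
    return np.std(pooled)


# synthetic check with unit-variance noise, then with 1.5x over-dispersed noise
rng = np.random.default_rng(1)
r, i = rng.normal(size=100_000), rng.normal(size=100_000)
ratio_ok = scatter_ratio(r, i)
ratio_over = scatter_ratio(1.5 * r, 1.5 * i)
```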

A more useful diagnostic is to image the residual visibilities. We can do that by treating them as a new “dataset” and using the diagnostic imaging capabilities of mpol.gridding.DirtyImager. Here we’ll choose a Briggs “robust” value of 0.0, but you are free to image the residual visibilities with whatever weighting makes sense for your science application. When doing this type of residual visualization, we recommend specifying the units of the dirty image using unit="Jy/arcsec^2", which are the natural units of MPoL images.
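For reference, Briggs "robust" weighting interpolates between natural weighting (large robust values) and uniform weighting (very negative robust values). Below is a NumPy sketch of the commonly used Briggs formula, in the form documented for CASA. It uses a hypothetical precomputed `cell_index` array, and it has not been checked line by line against mpol.gridding.DirtyImager, so treat it as illustrative.

```python
import numpy as np


def briggs_weights(cell_index, weight, ncells, robust):
    """Briggs (robust) reweighting of visibilities that share grid cells.

        f^2  = (5 * 10**(-robust))**2 / (sum_k W_k**2 / sum_i w_i)
        w_i' = w_i / (1 + W_k * f^2)
    where W_k is the summed weight in the cell containing visibility i.
    """
    W = np.bincount(cell_index, weights=weight, minlength=ncells)  # summed weight per cell
    f2 = (5 * 10 ** (-robust)) ** 2 / (np.sum(W ** 2) / np.sum(weight))
    return weight / (1 + W[cell_index] * f2)


cell_index = np.array([0, 0, 0, 1, 2, 2])
weight = np.array([1.0, 2.0, 3.0, 1.0, 0.5, 0.5])
w_natural = briggs_weights(cell_index, weight, 3, robust=10.0)   # close to the input weights
w_uniform = briggs_weights(cell_index, weight, 3, robust=-10.0)  # close to equal weight per cell
```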

imager = gridding.DirtyImager(
    coords=coords,
    uu=uu,
    vv=vv,
    weight=weight,
    data_re=np.real(vis_resid),
    data_im=np.imag(vis_resid),
)

img, beam = imager.get_dirty_image(weighting="briggs", robust=0.0, unit="Jy/arcsec^2")

fig, ax = plt.subplots(nrows=1)
im = ax.imshow(
    img[0],
    origin="lower",
    interpolation="none",
    extent=rml.icube.coords.img_ext,
)
plt.xlabel(r"$\Delta \alpha \cos \delta$ [${}^{\prime\prime}$]")
plt.ylabel(r"$\Delta \delta$ [${}^{\prime\prime}$]")
plt.colorbar(im, label=r"Jy/$\mathrm{arcsec}^2$");

We see that the residual image appears to be mostly noise, though with some low-frequency spatial structure. This suggests that our model has done a reasonable job of capturing the signal in the data. But there is still obvious room for improvement in regularizing some of the noisier features. In the following tutorials we’ll examine how to set up additional regularizer terms that will yield more desirable image characteristics.

Hopefully this tutorial has demonstrated the core concepts of synthesizing an image with MPoL. If you have any questions about the process, please feel free to reach out and start a GitHub discussion. If you spot a bug or have an idea to improve these tutorials, please raise a GitHub issue or better yet submit a pull request.