# Source code for mpol.crossval

from __future__ import annotations

import copy
import logging
from typing import Any

import numpy as np
import torch
from numpy import floating
from numpy.typing import NDArray

from mpol.datasets import Dartboard, GriddedDataset

# from mpol.training import TrainTest, train_to_dirty_image


# class CrossValidate:
#     r"""
#     Utilities to run a cross-validation loop (implicitly running a training
#     optimization loop), in order to compare MPoL models with different
#     hyperparameter values

#     Parameters
#     ----------
#     coords : `mpol.coordinates.GridCoords` object
#         Instance of the `mpol.coordinates.GridCoords` class.
#     imager : `mpol.gridding.DirtyImager` object
#         Instance of the `mpol.gridding.DirtyImager` class.
#     learn_rate : float, default=0.3
#         Initial learning rate
#     regularizers : nested dict, default={}
#         Dictionary of image regularizers to use. For each, a dict of the
#         strength ('lambda', float), whether to guess an initial value for lambda
#         ('guess', bool), and other quantities needed to compute their loss term.
#         Example:
#         {"sparsity":{"lambda":1e-3, "guess":False},
#         "entropy": {"lambda":1e-3, "guess":True, "prior_intensity":1e-10}
#         }
#     epochs : int, default=10000
#         Number of training iterations
#     convergence_tol : float, default=1e-5
#         Tolerance for training iteration stopping criterion as assessed by
#         loss function (suggested <= 1e-3)
#     schedule_factor : float, default=0.995
#         For the `torch.optim.lr_scheduler.ReduceLROnPlateau` scheduler, factor
#         to which the learning rate is reduced when learning rate stops decreasing
#     start_dirty_image : bool, default=False
#         Whether to start the RML optimization loop by initializing the model
#         image to a dirty image of the observed data. If False, the optimization
#         loop will start with a blank image.
#     train_diag_step : int, default=None
#         Interval at which training diagnostics are output. If None, no
#         diagnostics will be generated.
#     kfolds : int, default=5
#         Number of k-folds to use in cross-validation
#     split_method : str, default='dartboard'
#         Method to split full dataset into train/test subsets
#     dartboard_q_edges, dartboard_phi_edges : list of float, default=None,
#         unit=[klambda]
#         Radial and azimuthal bin edges of the cells used to split the dataset
#         if `split_method`==`dartboard` (see `datasets.Dartboard`)
#     split_diag_fig : bool, default=False
#         Whether to generate a diagnostic figure of dataset splitting into
#         train/test sets.
#     store_cv_diagnostics : bool, default=False
#         Whether to store diagnostics of the cross-validation loop.
#     save_prefix : str, default=None
#         Prefix (path) used for saved figure names. If None, figures won't be
#         saved
#     verbose : bool, default=True
#         Whether to print notification messages.
#     device : torch.device, default=None
#         Which hardware device to perform operations on (e.g., 'cuda:0').
#         'None' defaults to current device.
#     seed : int, default=None
#         Seed for random number generator used in splitting data
#     """

#     def __init__(
#         self,
#         coords,
#         imager,
#         learn_rate=0.3,
#         regularizers={},
#         epochs=10000,
#         convergence_tol=1e-5,
#         schedule_factor=0.995,
#         start_dirty_image=False,
#         train_diag_step=None,
#         kfolds=5,
#         split_method="dartboard",
#         dartboard_q_edges=None,
#         dartboard_phi_edges=None,
#         split_diag_fig=False,
#         store_cv_diagnostics=False,
#         save_prefix=None,
#         verbose=True,
#         device=None,
#         seed=None,
#     ):
#         self._coords = coords
#         self._imager = imager
#         self._learn_rate = learn_rate
#         self._regularizers = regularizers
#         self._epochs = epochs
#         self._convergence_tol = convergence_tol
#         self._schedule_factor = schedule_factor
#         self._start_dirty_image = start_dirty_image
#         self._train_diag_step = train_diag_step
#         self._kfolds = kfolds
#         self._split_method = split_method
#         self._dartboard_q_edges = dartboard_q_edges
#         self._dartboard_phi_edges = dartboard_phi_edges
#         self._split_diag_fig = split_diag_fig
#         self._store_cv_diagnostics = store_cv_diagnostics
#         self._save_prefix = save_prefix
#         self._verbose = verbose
#         self._device = device
#         self._seed = seed

#         self._split_figure = None

#         # used to collect objects across all kfolds
#         self._diagnostics = None

#         if self._verbose:
#             logging.info("\nCross-validation")

#     def split_dataset(self, dataset):
#         r"""
#         Split a dataset into training and test subsets.

#         Parameters
#         ----------
#         dataset : PyTorch dataset object
#             Instance of the `mpol.datasets.GriddedDataset` class

#         Returns
#         -------
#         split_iterator : iterator returning tuple
#             Iterator that provides a (train, test) pair of
#             :class:`~mpol.datasets.GriddedDataset` for each k-fold
#         """
#         if self._split_method == "random_cell":
#             split_iterator = RandomCellSplitGridded(
#                 dataset=dataset, k=self._kfolds, seed=self._seed
#             )

#         elif self._split_method == "dartboard":
#             if self._dartboard_q_edges is None:
#                 # create a radial partition for the dataset.
#                 # this is the same as the default q_edges in `datasets.Dartboard`,
#                 # except that the max baseline is set by (a padding factor times)
#                 # the maximum baseline in the dataset, rather than by the largest
#                 # baseline in the Fourier plane grid `coords.q_max` (which can
#                 # often be a factor of >~2 larger than the longest baseline in
#                 # the dataset).
#                 stacked_mask = torch.any(dataset.mask, axis=0)
#                 stacked_mask = stacked_mask.to("cpu")  # TODO: remove
#                 qs = dataset.coords.packed_q_centers_2D[stacked_mask]
#                 pad_factor = 1.1
#                 q_edges = loglinspace(0, qs.max() * pad_factor, N_log=8, M_linear=5)

#             dartboard = Dartboard(
#                 coords=self._coords,
#                 q_edges=q_edges,
#                 phi_edges=self._dartboard_phi_edges,
#             )

#             # use 'dartboard' to split full dataset into train/test subsets
#             split_iterator = DartboardSplitGridded(
#                 dataset, k=self._kfolds, dartboard=dartboard, seed=self._seed
#             )

#             if self._verbose:
#                 logging.info(
#                     f"  Max baseline in Fourier grid {self._coords.q_max:.0f} klambda"
#                 )
#                 logging.info(
#                     f"    Dartboard: baseline bin edges {[round(x, 1) for x in dartboard.q_edges.tolist()]} klambda"
#                 )

#         else:
#             supported_methods = ["dartboard", "random_cell"]
#             raise ValueError(
#                 "'split_method' {} must be one of "
#                 "{}".format(self._split_method, supported_methods)
#             )

#         return split_iterator

#     def run_crossval(self, dataset):
#         r"""
#         Run a cross-validation loop for a model obtained with a given set of
#         hyperparameters.

#         Parameters
#         ----------
#         dataset : dataset object
#             Instance of the `mpol.datasets.GriddedDataset` class
#         Returns
#         -------
#         cv_score : dict
#             Dictionary with mean and standard deviation of cross-validation
#             scores across all k-folds, and all raw scores
#         """
#         all_scores = []
#         if self._store_cv_diagnostics:
#             self._diagnostics = defaultdict(list)

#         split_iterator = self.split_dataset(dataset)
#         if self._split_diag_fig:
#             split_fig, split_axes = split_diagnostics_fig(
#                 split_iterator, save_prefix=self._save_prefix
#             )
#             self._split_figure = (split_fig, split_axes)

#         for kk, (train_set, test_set) in enumerate(split_iterator):
#             if self._verbose:
#                 logging.info("\n  k-fold {} of {}".format(kk, self._kfolds - 1))

#             # if hasattr(self._device,'type') and self._device.type == 'cuda': # TODO: confirm which objects need to be passed to gpu
#             #     train_set, test_set = train_set.to(self._device), test_set.to(self._device)

#             model = GriddedNet(coords=self._coords, nchan=self._imager.nchan)
#             if self._start_dirty_image is True:
#                 if kk == 0:
#                     if self._verbose:
#                         logging.info(
#                             "\n  Pre-training to dirty image to initialize subsequent optimization loops"
#                         )
#                     # initial short training loop to get model image to approximate dirty image
#                     model_pretrained = train_to_dirty_image(
#                         model=model, imager=self._imager
#                     )
#                     # save the model to a state we can load in subsequent kfolds
#                     torch.save(
#                         model_pretrained.state_dict(),
#                         f=self._save_prefix + "_dirty_image_model.pt",
#                     )
#                 else:
#                     # create a new model for this kfold, initializing it to the model pretrained on the dirty image
#                     model.load_state_dict(
#                         torch.load(self._save_prefix + "_dirty_image_model.pt")
#                     )

#             # create a new optimizer and scheduler for this kfold
#             optimizer = torch.optim.Adam(model.parameters(), lr=self._learn_rate)
#             if self._schedule_factor is None:
#                 scheduler = None
#             else:
#                 scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(
#                     optimizer, mode="min", factor=self._schedule_factor
#                 )

#             trainer = TrainTest(
#                 imager=self._imager,
#                 optimizer=optimizer,
#                 scheduler=scheduler,
#                 epochs=self._epochs,
#                 convergence_tol=self._convergence_tol,
#                 regularizers=self._regularizers,
#                 train_diag_step=self._train_diag_step,
#                 kfold=kk,
#                 save_prefix=self._save_prefix,
#                 verbose=self._verbose,
#             )

#             # run training
#             loss, loss_history = trainer.train(model, train_set)

#             # run testing
#             all_scores.append(trainer.test(model, test_set))

#             # store objects from the most recent kfold for diagnostics
#             self._model = model
#             self._train_figure = trainer.train_figure

#             # collect objects from this kfold to store
#             if self._store_cv_diagnostics:
#                 self._diagnostics["models"].append(self._model)
#                 self._diagnostics["regularizers"].append(self._regularizers)
#                 self._diagnostics["loss_histories"].append(loss_history)
#                 # self._diagnostics["train_figures"].append(self._train_figure)

#         # average individual test scores to get the cross-val metric for chosen
#         # hyperparameters
#         self._cv_score = {
#             "mean": np.mean(all_scores),
#             "std": np.std(all_scores),
#             "all": all_scores,
#         }

#         return self._cv_score

#     @property
#     def model(self):
#         """For the most recent kfold, trained model (`GriddedNet` class instance)"""
#         return self._model

#     @property
#     def regularizers(self):
#         """For the most recent kfold, dict containing regularizers used and their strengths"""
#         return self._regularizers

#     @property
#     def score(self):
#         """Dict containing cross-val scores for all k-folds, and mean and standard deviation of these"""
#         return self._cv_score

#     @property
#     def split_method(self):
#         """String of the method used to split the dataset into train/test sets"""
#         return self._split_method

#     @property
#     def train_figure(self):
#         """For the most recent kfold, (fig, axes) showing training progress"""
#         return self._train_figure

#     @property
#     def split_figure(self):
#         """(fig, axes) of train/test splitting diagnostic figure"""
#         return self._split_figure

#     @property
#     def diagnostics(self):
#         """Dict containing diagnostics of the cross-validation loop across all kfolds: models, regularizers, loss values"""
#         return self._diagnostics


class RandomCellSplitGridded:
    r"""
    Split a GriddedDataset into :math:`k` subsets. Inherit the properties of
    the GriddedDataset. This object creates an iterator providing a
    (train, test) pair of :class:`~mpol.datasets.GriddedDataset` for each
    k-fold.

    Parameters
    ----------
    dataset : PyTorch dataset object
        Instance of the `mpol.datasets.GriddedDataset` class
    k : int, default=5
        Number of k-folds (partitions) of `dataset`
    seed : int, default=None
        Seed for PyTorch random number generator used to shuffle data before
        splitting
    channel : int, default=0
        Channel of the dataset to use in determining the splits

    Notes
    -----
    Once initialized, iterate through the datasets like:
        >>> split_iterator = crossval.RandomCellSplitGridded(dataset, k)
        >>> for (train, test) in split_iterator: # iterate through `k` datasets
        >>> ... # working with the n-th slice of `k` datasets
        >>> ... # do operations with train dataset
        >>> ... # do operations with test dataset

    Treats `dataset` as a single-channel object with all data in `channel`.

    The splitting doesn't select (preserve) Hermitian pairs of visibilities.

    All train splits have the highest 1% of cells by gridded weight
    """

    def __init__(self, dataset, k=5, seed=None, channel=0):
        self.dataset = dataset
        self.k = k
        self.channel = channel

        # get indices for cells in the top 1% of gridded weight
        # (we'll want all training sets to have these high SNR points)
        nvis = len(self.dataset.vis_indexed)
        # clamp to at least 1: for nvis < 100, int(nvis * 0.01) == 0 and the
        # threshold below would degenerate to the *minimum* weight, flagging
        # every data cell as "top" and leaving all test splits empty
        nn = max(int(nvis * 0.01), 1)
        # get the nn-th largest value in weight_indexed
        w_thresh = np.partition(self.dataset.weight_indexed, -nn)[-nn]
        self._top_nn = torch.argwhere(
            self.dataset.weight_gridded[self.channel] >= w_thresh
        ).T

        # mask these indices
        self.top_mask = torch.ones(
            self.dataset.weight_gridded[self.channel].shape, dtype=bool
        )
        self.top_mask[self._top_nn[0], self._top_nn[1]] = False

        # use unmasked cells that also have data for splits
        self.split_mask = torch.logical_and(
            self.dataset.mask[self.channel], self.top_mask
        )
        split_idx = torch.argwhere(self.split_mask).T

        # shuffle indices to prevent radial/azimuthal patterns in splits
        if seed is not None:
            torch.manual_seed(seed)
        shuffle = torch.randperm(split_idx.shape[1])
        split_idx = split_idx[:, shuffle]

        # split indices into k subsets
        self.splits = torch.tensor_split(split_idx, self.k, dim=1)

    def __iter__(self):
        # current k-slice
        self._n = 0
        return self

    def __next__(self):
        """Return the (train, test) dataset pair for the current k-fold."""
        if self._n < self.k:
            test_idx = self.splits[self._n]
            train_idx = torch.cat(
                ([self.splits[x] for x in range(len(self.splits)) if x != self._n]),
                dim=1,
            )
            # add the masked (high SNR) points to the current training set
            train_idx = torch.cat((train_idx, self._top_nn), dim=1)

            train_mask = torch.zeros(
                self.dataset.weight_gridded[self.channel].shape, dtype=bool
            )
            test_mask = torch.zeros(
                self.dataset.weight_gridded[self.channel].shape, dtype=bool
            )
            train_mask[train_idx[0], train_idx[1]] = True
            test_mask[test_idx[0], test_idx[1]] = True

            # copy original dataset
            train = copy.deepcopy(self.dataset)
            test = copy.deepcopy(self.dataset)

            # use the masks to limit new datasets to only unmasked cells
            train.add_mask(train_mask)
            test.add_mask(test_mask)

            self._n += 1

            return train, test
        else:
            raise StopIteration
class DartboardSplitGridded:
    r"""
    Split a GriddedDataset into :math:`k` non-overlapping chunks, internally
    partitioned by a Dartboard. Inherit the properties of the GriddedDataset.
    This object creates an iterator providing a (train, test) pair of
    :class:`~mpol.datasets.GriddedDataset` for each k-fold.

    Args:
        griddedDataset (:class:`~mpol.datasets.GriddedDataset`): instance of
            the gridded dataset
        k (int): the number of subpartitions of the dataset
        dartboard (:class:`~mpol.datasets.Dartboard`): a pre-initialized
            Dartboard instance. If ``dartboard`` is provided, do not provide
            ``q_edges`` or ``phi_edges``.
        q_edges (1D numpy array): an array of radial bin edges to set the
            dartboard cells in :math:`[\mathrm{k}\lambda]`. If ``None``,
            defaults to 12 log-linearly radial bins stretching from 0 to the
            :math:`q_\mathrm{max}` represented by ``coords``.
        phi_edges (1D numpy array): an array of azimuthal bin edges to set the
            dartboard cells in [radians]. If ``None``, defaults to 8
            equal-spaced azimuthal bins stretched from :math:`0` to
            :math:`\pi`.
        seed (int): (optional) numpy random seed to use for the permutation,
            for reproducibility
        verbose (bool): whether to print notification messages

    Once initialized, iterate through the datasets like

    >>> cv = crossval.DartboardSplitGridded(dataset, k)
    >>> for (train, test) in cv: # iterate among k datasets
    >>> ... # working with the n-th slice of k datasets
    >>> ... # do operations with train dataset
    >>> ... # do operations with test dataset

    Notes:
        All train splits have the cells belonging to the shortest dartboard
        baseline bin.

        The number of points in the splits is in general not equal.
    """

    def __init__(
        self,
        gridded_dataset: GriddedDataset,
        k: int,
        dartboard: Dartboard | None = None,
        seed: int | None = None,
        verbose: bool = True,
    ):
        if k <= 0:
            raise ValueError("k must be a positive integer")

        # default partition when no dartboard is supplied
        if dartboard is None:
            dartboard = Dartboard(coords=gridded_dataset.coords)

        self.griddedDataset = gridded_dataset
        self.k = k
        self.dartboard = dartboard
        self.verbose = verbose

        # 2D mask for any UV cells that contain visibilities
        # in *any* channel
        stacked_mask = torch.any(self.griddedDataset.mask, dim=0)

        # get qs, phis from dataset and turn into 1D lists
        qs = self.griddedDataset.coords.packed_q_centers_2D[stacked_mask]
        phis = self.griddedDataset.coords.packed_phi_centers_2D[stacked_mask]

        # create the full cell_list
        self.cell_list = self.dartboard.get_nonzero_cell_indices(qs, phis)

        # indices of cells in the smallest q bin that also have data
        small_q_idx = [i for i, l in enumerate(self.cell_list) if l[0] == 0]
        # cells in the smallest q bin
        # NOTE(review): the slice assumes q-bin-0 cells come first in
        # cell_list — confirm against Dartboard.get_nonzero_cell_indices
        # ordering
        self.small_q = self.cell_list[: len(small_q_idx)]

        # partition the cell_list into k pieces.
        # first, randomly permute the sequence to make sure
        # we don't get structured radial/azimuthal patterns.
        # also exclude the cells belonging to the smallest q bin from all splits
        # (we'll add these only to the training splits, as they're iterated through)
        if seed is not None:
            np.random.seed(seed)
        self.k_split_cell_list = np.array_split(
            np.random.permutation(self.cell_list[len(small_q_idx) :]), k
        )

    @classmethod
    def from_dartboard_properties(
        cls,
        gridded_dataset: GriddedDataset,
        k: int,
        q_edges: NDArray[floating[Any]],
        phi_edges: NDArray[floating[Any]],
        seed: int | None = None,
        verbose: bool = True,
    ) -> DartboardSplitGridded:
        r"""
        Alternative method to initialize a DartboardSplitGridded object from
        Dartboard parameters.

        Args:
            griddedDataset (:class:`~mpol.datasets.GriddedDataset`): instance
                of the gridded dataset
            k (int): the number of subpartitions of the dataset
            q_edges (1D numpy array): an array of radial bin edges to set the
                dartboard cells in :math:`[\mathrm{k}\lambda]`. If ``None``,
                defaults to 12 log-linearly radial bins stretching from 0 to
                the :math:`q_\mathrm{max}` represented by ``coords``.
            phi_edges (1D numpy array): an array of azimuthal bin edges to set
                the dartboard cells in [radians]. If ``None``, defaults to 8
                equal-spaced azimuthal bins stretched from :math:`0` to
                :math:`\pi`.
            seed (int): (optional) numpy random seed to use for the
                permutation, for reproducibility
            verbose (bool): whether to print notification messages
        """
        dartboard = Dartboard(gridded_dataset.coords, q_edges, phi_edges)
        return cls(gridded_dataset, k, dartboard, seed, verbose)

    def __iter__(self) -> DartboardSplitGridded:
        self.n = 0  # the current k-slice we're on
        return self

    def __next__(self) -> tuple[GriddedDataset, GriddedDataset]:
        """Return the (train, test) dataset pair for the current k-fold."""
        if self.n < self.k:
            # shallow list copy: elements may be rebound/popped below without
            # touching self.k_split_cell_list itself
            k_list = self.k_split_cell_list.copy()
            if self.k == 1:
                if self.verbose is True:
                    logging.info(
                        "  DartboardSplitGridded: only 1 k-fold: splitting dataset as ~80/20 train/test"
                    )
                ntest = round(0.2 * len(k_list[0]))
                # put ~20% of cells into test set
                cell_list_test = k_list[0][:ntest]
                # remove cells in test set from train set
                k_list[0] = np.delete(k_list[0], range(ntest), axis=0)
            else:
                cell_list_test = k_list.pop(self.n)

            # put the remaining indices back into a full list
            cell_list_train = np.concatenate(k_list)
            # add the smallest q bin cells into the train list
            cell_list_train = np.append(cell_list_train, self.small_q, axis=0)

            # create the masks for each cell_list
            train_mask = self.dartboard.build_grid_mask_from_cells(cell_list_train)
            test_mask = self.dartboard.build_grid_mask_from_cells(cell_list_test)

            # copy original dataset
            train = copy.deepcopy(self.griddedDataset)
            test = copy.deepcopy(self.griddedDataset)

            # and use these masks to limit new datasets to only unmasked cells
            train.add_mask(train_mask)
            test.add_mask(test_mask)

            self.n += 1

            return train, test
        else:
            raise StopIteration