Cross-validation

class mpol.crossval.RandomCellSplitGridded(dataset, k=5, seed=None, channel=0)

Split a GriddedDataset into \(k\) subsets. Each subset inherits the properties of the original GriddedDataset. This object is an iterator that yields one (train, test) pair of GriddedDataset objects for each of the \(k\) folds.

Parameters:
  • dataset (PyTorch dataset object) – Instance of the mpol.datasets.GriddedDataset class

  • k (int, default=5) – Number of folds (partitions) of the dataset

  • seed (int, default=None) – Seed for PyTorch random number generator used to shuffle data before splitting

  • channel (int, default=0) – Channel of the dataset to use in determining the splits

Notes

Once initialized, iterate through the datasets like:
>>> from mpol import crossval
>>> split_iterator = crossval.RandomCellSplitGridded(dataset, k)
>>> for (train, test) in split_iterator:  # iterate through the `k` folds
...     # working with the n-th (train, test) pair of `k`
...     # do operations with train dataset
...     # do operations with test dataset
...     pass

Treats dataset as a single-channel object, taking all data from the channel given by the channel argument.

The splitting doesn’t select (preserve) Hermitian pairs of visibilities.

Every train split includes the highest 1% of cells by gridded weight.
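
The splitting behaviour described in these notes can be sketched in a few lines of NumPy. This is a minimal illustration, not the library's implementation: the helper `random_cell_kfold`, its `keep_fraction` argument, the flat `weights` array, and the use of NumPy's random generator (the class itself is documented as using a PyTorch seed) are all assumptions made for the example.

```python
import numpy as np


def random_cell_kfold(weights, k=5, seed=None, keep_fraction=0.01):
    """Yield (train_idx, test_idx) index arrays over populated cells.

    `weights` is a 1D array of gridded weights; cells with zero weight are
    treated as unpopulated. Hypothetical helper, not an MPoL function.
    """
    weights = np.asarray(weights, dtype=float)
    populated = np.flatnonzero(weights > 0)

    # Reserve the highest-weight cells so that every train split contains them.
    n_keep = max(1, int(np.ceil(keep_fraction * populated.size)))
    order = populated[np.argsort(weights[populated])[::-1]]
    always_train = order[:n_keep]
    remaining = order[n_keep:]

    # Shuffle the remaining cells and cut them into k nearly equal folds.
    rng = np.random.default_rng(seed)
    folds = np.array_split(rng.permutation(remaining), k)

    for i in range(k):
        test_idx = folds[i]
        train_parts = [always_train] + [folds[j] for j in range(k) if j != i]
        yield np.concatenate(train_parts), test_idx


# Synthetic weights stand in for one channel of a gridded dataset.
weights = np.random.default_rng(0).uniform(0.0, 1.0, size=1000)
splits = list(random_cell_kfold(weights, k=5, seed=42))
```

Each fold serves as the test set exactly once, while the reserved high-weight cells never leave the train set.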

class mpol.crossval.DartboardSplitGridded(gridded_dataset: GriddedDataset, k: int, dartboard: Dartboard | None = None, seed: int | None = None, verbose: bool = True)

Split a GriddedDataset into \(k\) non-overlapping chunks, internally partitioned by a Dartboard. Each chunk inherits the properties of the original GriddedDataset. This object is an iterator that yields one (train, test) pair of GriddedDataset objects for each of the \(k\) folds.

Parameters:
  • gridded_dataset (GriddedDataset) – instance of the gridded dataset

  • k (int) – the number of subpartitions of the dataset

  • dartboard (Dartboard) – a pre-initialized Dartboard instance. If dartboard is provided, do not provide q_edges or phi_edges.

  • q_edges (1D numpy array) – an array of radial bin edges to set the dartboard cells in \([\mathrm{k}\lambda]\). If None, defaults to 12 log-linearly spaced radial bins stretching from 0 to the \(q_\mathrm{max}\) represented by coords.

  • phi_edges (1D numpy array) – an array of azimuthal bin edges to set the dartboard cells in [radians]. If None, defaults to 8 equally spaced azimuthal bins stretching from \(0\) to \(\pi\).

  • seed (int) – (optional) numpy random seed to use for the permutation, for reproducibility

  • verbose (bool) – whether to print notification messages

Once initialized, iterate through the datasets like

>>> from mpol import crossval
>>> cv = crossval.DartboardSplitGridded(dataset, k)
>>> for (train, test) in cv:  # iterate among the k folds
...     # working with the n-th (train, test) pair of k
...     # do operations with train dataset
...     # do operations with test dataset
...     pass

Notes

Every train split includes the cells belonging to the shortest-baseline dartboard bin.

The number of points in the splits is in general not equal.
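
To make the two notes above concrete, here is a minimal NumPy sketch of a dartboard-style split. It is an illustration under stated assumptions rather than the library's code: the helper `dartboard_kfold`, the flat `q` and `phi` coordinate arrays, and the evenly spaced radial edges used in the demo are all hypothetical, and every coordinate is assumed to fall inside the supplied edges.

```python
import numpy as np


def dartboard_kfold(q, phi, q_edges, phi_edges, k, seed=None):
    """Yield (train_mask, test_mask) boolean arrays over populated cells.

    `q` and `phi` are the radial and azimuthal coordinates of each cell.
    Hypothetical helper only; it does not call into MPoL.
    """
    q = np.asarray(q)
    phi = np.asarray(phi)
    n_phi = len(phi_edges) - 1

    # Locate each point on the polar grid (ring 0 holds the shortest baselines).
    q_bin = np.digitize(q, q_edges) - 1
    phi_bin = np.digitize(phi, phi_edges) - 1
    cell_id = q_bin * n_phi + phi_bin

    # The innermost ring stays in every train split; only the rest is shuffled.
    inner = q_bin == 0
    outer_cells = np.unique(cell_id[~inner])
    rng = np.random.default_rng(seed)
    chunks = np.array_split(rng.permutation(outer_cells), k)

    for chunk in chunks:
        test_mask = np.isin(cell_id, chunk) & ~inner
        yield ~test_mask, test_mask


# Synthetic coordinates: q in k-lambda, phi in radians on [0, pi).
demo_rng = np.random.default_rng(1)
q = demo_rng.uniform(0.0, 1000.0, size=500)
phi = demo_rng.uniform(0.0, np.pi, size=500)
q_edges = np.linspace(0.0, 1000.0, 13)
phi_edges = np.linspace(0.0, np.pi, 9)
dart_splits = list(dartboard_kfold(q, phi, q_edges, phi_edges, k=5, seed=0))
```

Because whole dartboard cells are assigned to folds and each cell holds a different number of points, the test sets generally differ in size, which matches the note above.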

classmethod from_dartboard_properties(gridded_dataset: GriddedDataset, k: int, q_edges: ndarray[Any, dtype[floating[Any]]], phi_edges: ndarray[Any, dtype[floating[Any]]], seed: int | None = None, verbose: bool = True) → DartboardSplitGridded

Alternative method to initialize a DartboardSplitGridded object from Dartboard parameters.

Parameters:
  • gridded_dataset (GriddedDataset) – instance of the gridded dataset

  • k (int) – the number of subpartitions of the dataset

  • q_edges (1D numpy array) – an array of radial bin edges to set the dartboard cells in \([\mathrm{k}\lambda]\). If None, defaults to 12 log-linearly spaced radial bins stretching from 0 to the \(q_\mathrm{max}\) represented by coords.

  • phi_edges (1D numpy array) – an array of azimuthal bin edges to set the dartboard cells in [radians]. If None, defaults to 8 equally spaced azimuthal bins stretching from \(0\) to \(\pi\).

  • seed (int) – (optional) numpy random seed to use for the permutation, for reproducibility

  • verbose (bool) – whether to print notification messages
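
As a rough illustration of the two edge arrays described in this parameter list, the sketch below builds them with NumPy. The values of `q_min` and `q_max` and the use of `np.geomspace` for the log-spaced part are assumptions made for the example and may differ from the library's own default spacing. The final call follows the signature shown above and is commented out because it needs an existing GriddedDataset, here assumed to be named `dataset`.

```python
import numpy as np

# Hypothetical baseline range in k-lambda; substitute values for your data.
q_min = 10.0
q_max = 2000.0

# 13 radial edges give 12 bins: one from 0 to q_min, then log-spaced bins
# up to q_max. This spacing is an assumption, not MPoL's exact default.
q_edges = np.concatenate(([0.0], np.geomspace(q_min, q_max, 12)))

# 9 azimuthal edges give 8 equally spaced bins from 0 to pi radians.
phi_edges = np.linspace(0.0, np.pi, 9)

# With a GriddedDataset in hand, the edges would be passed like this:
# from mpol import crossval
# cv = crossval.DartboardSplitGridded.from_dartboard_properties(
#     dataset, k=5, q_edges=q_edges, phi_edges=phi_edges, seed=42
# )
```

Passing explicit edges this way is an alternative to building a Dartboard first and handing it to the constructor.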