Cross-validation
- class mpol.crossval.RandomCellSplitGridded(dataset, k=5, seed=None, channel=0)
Split a GriddedDataset into \(k\) subsets, inheriting the properties of the GriddedDataset. This object creates an iterator that provides a (train, test) pair of GriddedDataset objects for each of the \(k\) folds.
- Parameters:
dataset (PyTorch dataset object) – Instance of the mpol.datasets.GriddedDataset class
k (int, default=5) – Number of folds (partitions) into which the dataset is split
seed (int, default=None) – Seed for the PyTorch random number generator used to shuffle the data before splitting
channel (int, default=0) – Channel of the dataset to use in determining the splits
Notes
- Once initialized, iterate through the datasets like:
>>> from mpol import crossval
>>> split_iterator = crossval.RandomCellSplitGridded(dataset, k)
>>> for (train, test) in split_iterator:  # iterate through `k` datasets
...     # working with the n-th slice of `k` datasets
...     # do operations with train dataset
...     # do operations with test dataset
...     pass
- Treats dataset as a single-channel object with all data in channel.
- The splitting does not select (preserve) Hermitian pairs of visibilities.
- All train splits contain the highest 1% of cells by gridded weight.
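The splitting rule described above can be sketched with NumPy. This is a minimal illustration under stated assumptions and is not MPoL's implementation. The function name random_cell_split_masks and the top_frac parameter are hypothetical. The sketch also assumes that the "highest 1%" is ranked among occupied (nonzero-weight) cells only. It works on a single-channel 2D weight grid and yields boolean (train, test) masks rather than GriddedDataset objects.

```python
import numpy as np


def random_cell_split_masks(weight, k=5, seed=None, top_frac=0.01):
    """Yield (train_mask, test_mask) boolean grids for k folds.

    Hypothetical helper, not part of MPoL. `weight` is a 2D array of gridded
    weights for one channel; cells with weight > 0 are assumed to hold data.
    """
    rng = np.random.default_rng(seed)
    flat = weight.ravel()
    occupied = np.flatnonzero(flat > 0)  # flat indices of cells with data

    # Rank occupied cells by weight (descending) and hold back the top fraction.
    n_top = max(1, int(np.ceil(top_frac * occupied.size)))
    order = np.argsort(flat[occupied])[::-1]
    pool = occupied[order[n_top:]]  # top cells are excluded from the pool

    # Shuffle the remaining cells and deal them into k roughly equal chunks.
    chunks = np.array_split(rng.permutation(pool), k)

    for test_idx in chunks:
        test_mask = np.zeros(flat.size, dtype=bool)
        test_mask[test_idx] = True
        # Every occupied cell not in this test chunk goes to train.
        train_mask = (flat > 0) & ~test_mask
        yield train_mask.reshape(weight.shape), test_mask.reshape(weight.shape)
```

The highest-weight cells are removed from the pool before shuffling. They therefore never land in a test fold, which matches the note that every train split keeps them.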
- class mpol.crossval.DartboardSplitGridded(gridded_dataset: GriddedDataset, k: int, dartboard: Dartboard | None = None, seed: int | None = None, verbose: bool = True)
Split a GriddedDataset into \(k\) non-overlapping chunks, internally partitioned by a Dartboard, inheriting the properties of the GriddedDataset. This object creates an iterator that provides a (train, test) pair of GriddedDataset objects for each of the \(k\) folds.
- Parameters:
gridded_dataset (GriddedDataset) – instance of the gridded dataset
k (int) – the number of subpartitions of the dataset
dartboard (Dartboard) – a pre-initialized Dartboard instance. If dartboard is provided, do not provide q_edges or phi_edges.
q_edges (1D numpy array) – an array of radial bin edges to set the dartboard cells in \([\mathrm{k}\lambda]\). If None, defaults to 12 log-linearly spaced radial bins stretching from 0 to the \(q_\mathrm{max}\) represented by coords.
phi_edges (1D numpy array) – an array of azimuthal bin edges to set the dartboard cells in [radians]. If None, defaults to 8 equally spaced azimuthal bins stretching from \(0\) to \(\pi\).
seed (int) – (optional) numpy random seed to use for the permutation, for reproducibility
verbose (bool) – whether to print notification messages
Once initialized, iterate through the datasets like:
>>> from mpol import crossval
>>> cv = crossval.DartboardSplitGridded(dataset, k)
>>> for (train, test) in cv:  # iterate among k datasets
...     # working with the n-th slice of k datasets
...     # do operations with train dataset
...     # do operations with test dataset
...     pass
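A typical use of the (train, test) pairs is k-fold cross-validation. You fit on each train split, evaluate on the matching test split, and average the held-out losses. The sketch below shows that generic pattern on a toy 1D array. The helper cross_val_score, the fit/loss callables, and the index-based folds are hypothetical stand-ins for training an imaging model on a GriddedDataset. They are not part of the MPoL API.

```python
import numpy as np


def cross_val_score(data, folds, fit, loss):
    """Average the held-out loss over (train, test) folds.

    Hypothetical helper, not part of MPoL.
    """
    scores = []
    for train_idx, test_idx in folds:
        model = fit(data[train_idx])  # train only on the train split
        scores.append(loss(model, data[test_idx]))  # score on unseen data
    return float(np.mean(scores))


# Toy example: the "model" is just the mean of the training data.
data = np.arange(10.0)
idx = np.arange(data.size)
folds = [(np.setdiff1d(idx, test), test) for test in np.array_split(idx, 5)]
score = cross_val_score(
    data, folds, fit=np.mean, loss=lambda m, d: float(np.mean((d - m) ** 2))
)
# score is 12.75 for this toy data
```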
Notes
- All train splits contain the cells belonging to the shortest-baseline (innermost radial) dartboard bin.
- The splits generally do not contain equal numbers of points.
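The dartboard partitioning can be sketched in NumPy as follows. This is a hypothetical illustration and is not MPoL's code. The name dartboard_split_masks is invented. The inputs u and v are assumed to be 2D arrays of cell-centre coordinates in kλ, with the same shape as the weight grid. Azimuths are folded onto \([0, \pi)\), on the assumption that Hermitian symmetry makes the opposite half-plane redundant. Whole dartboard cells, rather than individual grid cells, are shuffled and dealt into \(k\) groups. The innermost radial bin is kept out of every test set.

```python
import numpy as np


def dartboard_split_masks(u, v, weight, q_edges, phi_edges, k, seed=None):
    """Yield (train_mask, test_mask) for k folds built from dartboard cells.

    Hypothetical helper, not part of MPoL. u, v are 2D arrays of cell-centre
    coordinates [klambda] with the same shape as `weight`.
    """
    q = np.hypot(u, v)
    # Fold azimuths onto [0, pi), assuming Hermitian symmetry.
    phi = np.mod(np.arctan2(v, u), np.pi)

    # Bin index of each grid cell; out-of-range values fall outside [0, n).
    iq = np.digitize(q, q_edges) - 1
    iphi = np.digitize(phi, phi_edges) - 1
    n_q, n_phi = len(q_edges) - 1, len(phi_edges) - 1
    inside = (weight > 0) & (iq >= 0) & (iq < n_q) & (iphi >= 0) & (iphi < n_phi)

    # One integer label per dartboard cell; -1 marks empty or off-board cells.
    label = np.where(inside, iq * n_phi + iphi, -1)
    occupied = np.unique(label[inside])

    # Dartboard cells in the innermost radial bin (iq == 0) always stay in train.
    candidates = occupied[occupied // n_phi > 0]
    rng = np.random.default_rng(seed)
    chunks = np.array_split(rng.permutation(candidates), k)

    for chunk in chunks:
        test_mask = np.isin(label, chunk)
        train_mask = inside & ~test_mask
        yield train_mask, test_mask
```

Dartboard cells hold different numbers of occupied grid cells, so the test masks differ in size. This is consistent with the note that the splits are generally unequal.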
- classmethod from_dartboard_properties(gridded_dataset: GriddedDataset, k: int, q_edges: ndarray[Any, dtype[floating[Any]]], phi_edges: ndarray[Any, dtype[floating[Any]]], seed: int | None = None, verbose: bool = True) → DartboardSplitGridded
Alternative method to initialize a DartboardSplitGridded object from Dartboard parameters.
- Parameters:
gridded_dataset (GriddedDataset) – instance of the gridded dataset
k (int) – the number of subpartitions of the dataset
q_edges (1D numpy array) – an array of radial bin edges to set the dartboard cells in \([\mathrm{k}\lambda]\). If None, defaults to 12 log-linearly spaced radial bins stretching from 0 to the \(q_\mathrm{max}\) represented by coords.
phi_edges (1D numpy array) – an array of azimuthal bin edges to set the dartboard cells in [radians]. If None, defaults to 8 equally spaced azimuthal bins stretching from \(0\) to \(\pi\).
seed (int) – (optional) numpy random seed to use for the permutation, for reproducibility
verbose (bool) – whether to print notification messages
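One way to build bin-edge arrays shaped like the defaults described above is sketched below. This does not reproduce the exact default spacing MPoL uses. The helper make_dartboard_edges and its q_min parameter are hypothetical. The radial edges are taken as a single edge at 0 followed by logarithmically spaced edges from q_min to q_max, which yields 12 radial bins. The azimuthal edges give 8 equal bins over \([0, \pi]\).

```python
import numpy as np


def make_dartboard_edges(q_max, q_min, n_q=12, n_phi=8):
    """Return (q_edges, phi_edges) for a dartboard with n_q x n_phi cells.

    Hypothetical helper, not part of MPoL. q_max and q_min are in klambda.
    """
    # An edge at 0 plus n_q geometrically spaced edges gives n_q radial bins.
    q_edges = np.concatenate([[0.0], np.geomspace(q_min, q_max, n_q)])
    # n_phi + 1 evenly spaced edges give n_phi equal azimuthal bins on [0, pi].
    phi_edges = np.linspace(0.0, np.pi, n_phi + 1)
    return q_edges, phi_edges


q_edges, phi_edges = make_dartboard_edges(q_max=500.0, q_min=5.0)
```

Arrays built this way could then be passed as the q_edges and phi_edges arguments of from_dartboard_properties.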