Source code for mpol.training

import logging

import torch


def train_to_dirty_image(model, imager, robust=0.5, learn_rate=100, niter=1000):
    r"""
    Train against a dirty image of the observed visibilities, using a loss
    function of the mean squared error between the RML model image pixel
    fluxes and the dirty image pixel fluxes. Useful for initializing a
    separate RML optimization loop at a reasonable starting image.

    Parameters
    ----------
    model : `torch.nn.Module` object
        A neural network module; instance of the `mpol.precomposed.GriddedNet` class.
    imager : :class:`mpol.gridding.DirtyImager` object
        Instance of the `mpol.gridding.DirtyImager` class.
    robust : float, default=0.5
        Robust weighting parameter used to create the dirty image.
    learn_rate : float, default=100
        Learning rate for the optimization loop.
    niter : int, default=1000
        Number of iterations for the optimization loop.

    Returns
    -------
    model : `torch.nn.Module` object
        The input `model`, updated with the state of the training to the dirty image.
    """
    logging.info(" Initializing model to dirty image")

    img, beam = imager.get_dirty_image(
        weighting="briggs", robust=robust, unit="Jy/arcsec^2"
    )
    dirty_image = torch.tensor(img.copy())

    optimizer = torch.optim.SGD(model.parameters(), lr=learn_rate)

    losses = []
    for ii in range(niter):
        optimizer.zero_grad()

        # forward pass: updates the model image and visibility cubes
        model()
        sky_cube = model.icube.sky_cube

        lossfunc = torch.nn.MSELoss(reduction="sum")
        # with reduction="sum", MSELoss returns the summed squared error
        # (squared L2 norm), so take its square root
        loss = (lossfunc(sky_cube, dirty_image)) ** 0.5
        losses.append(loss.item())

        loss.backward()
        optimizer.step()

    return model
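# Example: a minimal usage sketch of `train_to_dirty_image` (illustrative only,
# not part of this module). The `DirtyImager` keyword arguments and the loose
# visibility arrays (`uu`, `vv`, `weight`, `data_re`, `data_im`) are assumed
# names that may differ between MPoL versions; consult the `mpol.gridding` and
# `mpol.precomposed` docs for the exact constructor signatures.
#
#     from mpol import coordinates, gridding, precomposed
#
#     coords = coordinates.GridCoords(cell_size=0.005, npix=800)
#     imager = gridding.DirtyImager(
#         coords=coords, uu=uu, vv=vv, weight=weight,
#         data_re=data_re, data_im=data_im,
#     )
#     model = precomposed.GriddedNet(coords=coords)
#
#     # initialize the model image to the robust=0.5 dirty image, then hand
#     # `model` off to a full RML optimization loop
#     model = train_to_dirty_image(model, imager, robust=0.5)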
# class TrainTest:
#     r"""
#     Utilities for training and testing an MPoL neural network.
#
#     Args:
#         imager (:class:`mpol.gridding.DirtyImager` object): Instance of the
#             `mpol.gridding.DirtyImager` class.
#         optimizer (:class:`torch.optim` object): PyTorch optimizer class for
#             the training loop.
#         scheduler (:class:`torch.optim.lr_scheduler` object, default=None):
#             Scheduler for adjusting the learning rate during optimization.
#         regularizers (nested dict): Dictionary of image regularizers to use.
#             For each, a dict of the strength ('lambda', float), whether to
#             guess an initial value for lambda ('guess', bool), and other
#             quantities needed to compute their loss term.
#             Example:
#             ``{"sparsity": {"lambda": 1e-3, "guess": False},
#                "entropy": {"lambda": 1e-3, "guess": True, "prior_intensity": 1e-10}
#               }``
#         epochs (int): Number of training iterations, default=10000
#         convergence_tol (float): Tolerance for the training iteration stopping
#             criterion as assessed by the loss function (suggested <= 1e-3)
#         train_diag_step (int): Interval at which training diagnostics are
#             output. If None, no diagnostics will be generated.
#         kfold (int): The k-fold of the current training set (for diagnostics)
#         save_prefix (str): Prefix (path) used for saved figure names.
#             If None, figures won't be saved.
#         verbose (bool): Whether to print notification messages.
#     """
#
#     def __init__(
#         self,
#         imager,
#         optimizer,
#         scheduler=None,
#         regularizers={},
#         epochs=10000,
#         convergence_tol=1e-5,
#         train_diag_step=None,
#         kfold=None,
#         save_prefix=None,
#         verbose=True,
#     ):
#         self._imager = imager
#         self._optimizer = optimizer
#         self._scheduler = scheduler
#         self._regularizers = regularizers
#         self._epochs = epochs
#         self._convergence_tol = convergence_tol
#         self._train_diag_step = train_diag_step
#         self._kfold = kfold
#         self._save_prefix = save_prefix
#         self._verbose = verbose
#         self._train_figure = None
#
#     def loss_convergence(self, loss):
#         r"""
#         Estimate whether the loss function has converged by assessing its
#         relative change over recent iterations.
#
#         Parameters
#         ----------
#         loss : array
#             Values of the loss function over iterations (epochs).
#             If len(loss) < 11, `False` will be returned, as convergence
#             cannot be adequately assessed.
#
#         Returns
#         -------
#         `True` if the convergence criterion is met, else `False`.
#         """
#         min_len = 11
#         if len(loss) < min_len:
#             return False
#
#         ratios = np.abs(loss[-1] / loss[-min_len:-1])
#
#         return all(1 - self._convergence_tol <= ratios) and all(
#             ratios <= 1 + self._convergence_tol
#         )
#
#     def loss_lambda_guess(self):
#         r"""
#         Set an initial guess for regularizer strengths :math:`\lambda_{x}` by
#         comparing images generated with different visibility weighting.
#         The guesses update `lambda` values in `self._regularizers`.
#         """
#         if self._verbose:
#             logging.info(
#                 " Updating regularizer strengths with automated "
#                 f"guessing. \nInitial values: {self._regularizers}"
#             )
#
#         # generate images of the data using two Briggs robust values
#         img1, _ = self._imager.get_dirty_image(weighting="briggs", robust=0.0)
#         img2, _ = self._imager.get_dirty_image(weighting="briggs", robust=0.5)
#         img1 = torch.from_numpy(img1.copy())
#         img2 = torch.from_numpy(img2.copy())
#
#         if self._regularizers.get("entropy", {}).get("guess") is True:
#             # force negative pixel values to a small positive value
#             img1_nn = torch.where(img1 < 0, 1e-10, img1)
#             img2_nn = torch.where(img2 < 0, 1e-10, img2)
#
#             loss_e1 = entropy(img1_nn, self._regularizers["entropy"]["prior_intensity"])
#             loss_e2 = entropy(img2_nn, self._regularizers["entropy"]["prior_intensity"])
#             guess_e = 1 / (loss_e2 - loss_e1)
#             # update stored value
#             self._regularizers["entropy"]["lambda"] = guess_e.numpy().item()
#
#         if self._regularizers.get("sparsity", {}).get("guess") is True:
#             loss_s1 = sparsity(img1)
#             loss_s2 = sparsity(img2)
#             guess_s = 1 / (loss_s2 - loss_s1)
#             self._regularizers["sparsity"]["lambda"] = guess_s.numpy().item()
#
#         if self._regularizers.get("TV", {}).get("guess") is True:
#             loss_TV1 = TV_image(img1, self._regularizers["TV"]["epsilon"])
#             loss_TV2 = TV_image(img2, self._regularizers["TV"]["epsilon"])
#             guess_TV = 1 / (loss_TV2 - loss_TV1)
#             self._regularizers["TV"]["lambda"] = guess_TV.numpy().item()
#
#         if self._regularizers.get("TSV", {}).get("guess") is True:
#             loss_TSV1 = TSV(img1)
#             loss_TSV2 = TSV(img2)
#             guess_TSV = 1 / (loss_TSV2 - loss_TSV1)
#             self._regularizers["TSV"]["lambda"] = guess_TSV.numpy().item()
#
#         if self._verbose:
#             logging.info(f" Updated values: {self._regularizers}")
#
#     def loss_eval(self, vis, dataset, sky_cube=None):
#         r"""
#         Evaluate the loss function from the model visibilities, the data, and
#         (optionally) the regularizers applied to the model sky cube.
#
#         Parameters
#         ----------
#         vis : torch.complex tensor
#             Model visibility cube (see `mpol.fourier.FourierCube.forward`)
#         dataset : dataset object
#             Instance of the `mpol.datasets.GriddedDataset` class.
#         sky_cube : torch.double
#             MPoL Ground Cube (see `mpol.utils.packed_cube_to_ground_cube`)
#
#         Returns
#         -------
#         loss : torch.double
#             Value of the loss function
#         """
#         # negative log-likelihood loss function
#         loss = r_chi_squared_gridded(vis, dataset)
#
#         # regularizers
#         if sky_cube is not None:
#             if self._regularizers.get("entropy", {}).get("lambda") is not None:
#                 loss += self._regularizers["entropy"]["lambda"] * entropy(
#                     sky_cube, self._regularizers["entropy"]["prior_intensity"]
#                 )
#             if self._regularizers.get("sparsity", {}).get("lambda") is not None:
#                 loss += self._regularizers["sparsity"]["lambda"] * sparsity(sky_cube)
#             if self._regularizers.get("TV", {}).get("lambda") is not None:
#                 loss += self._regularizers["TV"]["lambda"] * TV_image(
#                     sky_cube, self._regularizers["TV"]["epsilon"]
#                 )
#             if self._regularizers.get("TSV", {}).get("lambda") is not None:
#                 loss += self._regularizers["TSV"]["lambda"] * TSV(sky_cube)
#
#         return loss
#
#     def train(self, model, dataset):
#         r"""
#         Trains a neural network, forward modeling a visibility dataset and
#         evaluating the corresponding model image against the data, using
#         PyTorch with gradient descent.
#
#         Parameters
#         ----------
#         model : `torch.nn.Module` object
#             A neural network module; instance of the `mpol.precomposed.GriddedNet` class.
#         dataset : PyTorch dataset object
#             Instance of the `mpol.datasets.GriddedDataset` class.
#
#         Returns
#         -------
#         loss.item() : float
#             Value of the loss function at the end of the optimization loop
#         losses : list of float
#             Loss value at each iteration (epoch) in the loop
#         """
#         # set model to training mode
#         model.train()
#
#         count = 0
#         fluxes = []
#         losses = []
#         learn_rates = []
#         old_mod_im = None
#         old_mod_epoch = None
#
#         run_loss_guess = False
#         for _, v in self._regularizers.items():
#             for i in v:
#                 if v[i] is True:
#                     run_loss_guess = True
#
#         if run_loss_guess:
#             # guess initial strengths for regularizers in `self._regularizers`
#             # that have 'guess': True (this updates `self._regularizers`)
#             self.loss_lambda_guess()
#
#         if self._verbose:
#             logging.info(" Image regularizers: {}".format(self._regularizers))
#
#         while not self.loss_convergence(np.array(losses)) and count <= self._epochs:
#             if self._verbose:
#                 print(
#                     "\r Training: epoch {} of {}".format(count, self._epochs),
#                     end="",
#                     flush=True,
#                 )
#
#             # check early on whether the loss isn't evolving
#             if count == 10:
#                 loss_arr = np.array(losses)
#                 if all(0.9 <= loss_arr[:-1] / loss_arr[1:]) and all(
#                     loss_arr[:-1] / loss_arr[1:] <= 1.1
#                 ):
#                     warn_msg = (
#                         "The loss function is negligibly evolving. learn_rate "
#                         + "may be too low."
#                     )
#                     logging.info(warn_msg)
#                     raise Warning(warn_msg)
#
#             self._optimizer.zero_grad()
#
#             # calculate model visibility cube (corresponding to current pixel
#             # values of mpol.images.BaseCube)
#             vis = model()
#
#             # get predicted sky cube corresponding to model visibilities
#             sky_cube = model.icube.sky_cube
#
#             # total flux in model image
#             total_flux = model.coords.cell_size**2 * torch.sum(sky_cube)
#             fluxes.append(torch2npy(total_flux))
#
#             # calculate loss between model visibilities and data
#             loss = self.loss_eval(vis, dataset, sky_cube)
#             losses.append(loss.item())
#
#             # calculate gradients of loss function w.r.t. model parameters
#             loss.backward()
#
#             # update model parameters via gradient descent
#             self._optimizer.step()
#
#             if self._scheduler is not None:
#                 self._scheduler.step(loss)
#             learn_rates.append(self._optimizer.param_groups[0]["lr"])
#
#             # generate optional fit diagnostics
#             if self._train_diag_step is not None and (
#                 count % self._train_diag_step == 0
#                 or count == self._epochs
#                 or self.loss_convergence(np.array(losses))
#             ):
#                 train_fig, train_axes = train_diagnostics_fig(
#                     model,
#                     losses=losses,
#                     learn_rates=learn_rates,
#                     fluxes=fluxes,
#                     old_model_image=old_mod_im,
#                     old_model_epoch=old_mod_epoch,
#                     kfold=self._kfold,
#                     epoch=count,
#                     save_prefix=self._save_prefix,
#                 )
#                 self._train_figure = (train_fig, train_axes)
#
#                 # temporarily store the current model image for use in the
#                 # next call to `train_diagnostics_fig`
#                 old_mod_im = torch2npy(model.icube.sky_cube[0])
#                 old_mod_epoch = count * 1
#
#             count += 1
#
#         if self._verbose:
#             if count < self._epochs:
#                 logging.info(
#                     "\n Loss function convergence criterion met at epoch "
#                     "{}".format(count - 1)
#                 )
#             else:
#                 logging.info(
#                     "\n Loss function convergence criterion not met; "
#                     "training stopped at specified maximum epochs, {}".format(
#                         self._epochs
#                     )
#                 )
#
#         # return loss value
#         return loss.item(), losses
#
#     def test(self, model, dataset):
#         r"""
#         Test a model visibility cube against withheld data.
#
#         Parameters
#         ----------
#         model : `torch.nn.Module` object
#             A neural network module; instance of the `mpol.precomposed.GriddedNet` class.
#         dataset : PyTorch dataset object
#             Instance of the `mpol.datasets.GriddedDataset` class.
#
#         Returns
#         -------
#         loss.item() : float
#             Value of the loss function
#         """
#         # evaluate trained model against a set of withheld (test) visibilities
#         model.eval()
#
#         # calculate model visibility cube
#         vis = model()
#
#         # calculate loss used for a cross-validation score
#         loss = self.loss_eval(vis, dataset)
#
#         # return loss value
#         return loss.item()
#
#     @property
#     def regularizers(self):
#         """Dict containing regularizers used and their strengths"""
#         return self._regularizers
#
#     @property
#     def train_figure(self):
#         """(fig, axes) of figure showing training diagnostics"""
#         return self._train_figure
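

# Standalone sketch (not part of the MPoL API): a self-contained illustration
# of the relative-change convergence criterion used by the commented-out
# `TrainTest.loss_convergence` above. Training is deemed converged once the
# most recent loss value is within `tol`, in relative terms, of each of the
# previous ten values. The `loss_converged` helper below is a hypothetical
# name introduced only for this demonstration.
if __name__ == "__main__":
    import numpy as np

    def loss_converged(loss, tol=1e-5, min_len=11):
        # too few iterations to assess convergence
        if len(loss) < min_len:
            return False
        # ratios of the latest loss to each of the previous (min_len - 1) values
        ratios = np.abs(loss[-1] / loss[-min_len:-1])
        # equivalent to: all(1 - tol <= ratios) and all(ratios <= 1 + tol)
        return bool(np.all(np.abs(ratios - 1) <= tol))

    flat = np.ones(20)                         # fully plateaued loss curve
    decaying = np.exp(-0.5 * np.arange(20.0))  # still improving

    print(loss_converged(flat))      # True
    print(loss_converged(decaying))  # False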