# Source code for scalr.nn.callbacks.early_stopping

"""This file is an implementation of early stopping callback."""

import os
from os import path

import torch

from scalr.nn.callbacks import CallbackBase


class EarlyStopping(CallbackBase):
    """Implements early stopping based upon validation loss.

    Each call reports one epoch's validation loss. Training is stopped once
    ``patience`` consecutive epochs have passed without the loss dropping by
    more than ``min_delta`` below the best loss recorded so far.

    Attributes:
        patience: Number of epochs with no improvement after which training
            will be stopped.
        min_delta: Minimum decrease in the monitored quantity to qualify as
            an improvement, i.e. a decrease of ``min_delta`` or less (or any
            increase) counts as no improvement.
        epoch: Number of consecutive epochs without improvement seen so far.
        min_val_loss: Lowest validation loss that was recorded as an
            improvement.
    """

    def __init__(self,
                 dirpath: str = None,
                 patience: int = 3,
                 min_delta: float = 1e-4):
        """Initialize required parameters for early stopping callback.

        Args:
            dirpath: Accepted by the constructor but not used by this
                callback.
            patience: Number of epochs with no improvement after which
                training will be stopped.
            min_delta: Minimum decrease in the monitored quantity to qualify
                as an improvement, i.e. a decrease of ``min_delta`` or less
                counts as no improvement.
        """
        self.patience = int(patience)
        self.min_delta = float(min_delta)
        # Counter of consecutive epochs without improvement. The attribute
        # name is kept unchanged for code that may read it.
        self.epoch = 0
        self.min_val_loss = float('inf')

    def __call__(self, val_loss: float, **kwargs) -> bool:
        """Record the latest validation loss and decide whether to stop.

        Args:
            val_loss: Validation loss of the epoch that just finished.
            **kwargs: Ignored by this callback.

        Returns:
            ``True`` if model training needs to be stopped based upon the
            improvement conditions, ``False`` to continue training.
        """
        # Only a decrease of more than ``min_delta`` below the best loss is an
        # improvement. A loss equal to the best, a marginal decrease, an
        # increase, or NaN all count as an epoch without improvement. The
        # baseline is updated only on a real improvement, so a slow steady
        # decline still registers once it accumulates past ``min_delta``.
        if val_loss < self.min_val_loss - self.min_delta:
            self.min_val_loss = val_loss
            self.epoch = 0
            return False

        self.epoch += 1
        return self.epoch >= self.patience

    @classmethod
    def get_default_params(cls):
        """Class method to get default params for model_config."""
        return dict(patience=3, min_delta=1e-4)