"""This file is an implementation of early stopping callback."""
import os
from os import path
import torch
from scalr.nn.callbacks import CallbackBase
class EarlyStopping(CallbackBase):
    """
    Implements early stopping based upon validation loss.

    An epoch counts as an improvement only when its validation loss is lower
    than the best loss recorded so far by more than ``min_delta``. Every other
    epoch counts toward ``patience``. That includes a flat plateau, a decrease
    smaller than ``min_delta``, any increase, and a NaN loss. Once ``patience``
    consecutive non-improving epochs have been seen, the callback signals that
    training should stop.

    Attributes:
        patience: Number of consecutive epochs with no improvement after which
            training will be stopped.
        min_delta: Minimum decrease in validation loss that qualifies as an
            improvement. A decrease smaller than (or equal to) ``min_delta``
            counts as no improvement.
        epoch: Number of consecutive epochs without improvement seen so far
            (reset to 0 on every improvement). It is not the total number of
            epochs trained.
        min_val_loss: Best (lowest) validation loss recorded as an improvement.
    """

    def __init__(self,
                 dirpath: str = None,
                 patience: int = 3,
                 min_delta: float = 1e-4):
        """Initialize required parameters for early stopping callback.

        Args:
            dirpath: Accepted for interface compatibility with other callbacks.
                It is not used by this callback.
            patience: Number of consecutive epochs with no improvement after
                which training will be stopped.
            min_delta: Minimum decrease in validation loss that qualifies as an
                improvement. A decrease smaller than this counts as no
                improvement.
        """
        self.patience = int(patience)
        self.min_delta = float(min_delta)
        # Consecutive non-improving epochs; the name is kept for compatibility.
        self.epoch = 0
        # Starts at +inf so the first finite loss is always an improvement.
        self.min_val_loss = float('inf')

    def __call__(self, val_loss: float, **kwargs) -> bool:
        """Update the stopping state with the latest validation loss.

        Args:
            val_loss: Validation loss of the epoch that just finished.
            **kwargs: Ignored; accepted so callers can pass a shared set of
                callback arguments.

        Returns:
            `True` if model training needs to be stopped because the loss has
            not improved by more than ``min_delta`` for ``patience``
            consecutive epochs. Otherwise `False`, so training continues.
        """
        # Improvement requires beating the best loss by more than min_delta.
        # Plateaus and tiny decreases must count toward patience; otherwise
        # a flat loss curve would never trigger early stopping.
        if val_loss < self.min_val_loss - self.min_delta:
            self.min_val_loss = val_loss
            self.epoch = 0
            return False

        # Any other epoch (including NaN, where the comparison is False)
        # counts as no improvement.
        self.epoch += 1
        return self.epoch >= self.patience

    @classmethod
    def get_default_params(cls):
        """Class method to get default params for model_config.

        The values mirror the defaults of `__init__` (``dirpath`` excluded).
        """
        return dict(patience=3, min_delta=1e-4)