"""This file is an implementation of model checkpoint callback."""
import os
from os import path
import torch
from scalr.nn.callbacks import CallbackBase
class ModelCheckpoint(CallbackBase):
    """Model checkpointing to save model weights at regular intervals.

    Checkpoints are written to ``<dirpath>/checkpoints/model_<epoch>.pt``
    every ``interval`` calls of the callback. An ``interval`` of ``0``
    disables checkpointing entirely.

    Attributes:
        epoch: An integer count of how many times the callback has been
            called (one call per trained epoch is assumed).
        interval: Regular interval, in epochs, of model checkpointing.
        dirpath: Base directory under which the ``checkpoints`` folder lives.
    """

    def __init__(self, dirpath: str, interval: int = 5) -> None:
        """Initialize required parameters for model checkpoint callback.

        Args:
            dirpath: To store the respective model checkpoints.
            interval: Regular interval of model checkpointing. It is coerced
                with ``int()``; a value of ``0`` turns checkpointing off.
        """
        self.epoch = 0
        self.interval = int(interval)
        self.dirpath = dirpath

        # Only create the output directory when checkpointing is enabled, so
        # a disabled callback leaves the filesystem untouched.
        if self.interval:
            os.makedirs(path.join(dirpath, 'checkpoints'), exist_ok=True)

    def save_checkpoint(self, model_state_dict: dict, opt_state_dict: dict,
                        path: str) -> None:
        """A function to save model & optimizer state dict to the given path.

        The saved object is a dict with the keys ``epoch``,
        ``model_state_dict`` and ``optimizer_state_dict``.

        Args:
            model_state_dict: Model's state dict.
            opt_state_dict: Optimizer's state dict.
            path: Path to store checkpoint to.
        """
        # NOTE: the `path` argument shadows the module-level `os.path` import
        # inside this method; the name is kept for interface compatibility.
        checkpoint = {
            'epoch': self.epoch,
            'model_state_dict': model_state_dict,
            'optimizer_state_dict': opt_state_dict,
        }
        torch.save(checkpoint, path)

    def __call__(self, model_state_dict: dict, opt_state_dict: dict,
                 **kwargs) -> None:
        """A function that evaluates when to save a checkpoint.

        Each call advances the epoch counter by one. When checkpointing is
        enabled and the new epoch count is a multiple of ``interval``, the
        states are written to ``<dirpath>/checkpoints/model_<epoch>.pt``.

        Args:
            model_state_dict: Model's state dict.
            opt_state_dict: Optimizer's state dict.
            **kwargs: Extra keyword arguments; accepted and ignored.
        """
        self.epoch += 1

        # `self.interval` is checked first so that interval == 0 never reaches
        # the modulo (which would raise ZeroDivisionError).
        if self.interval and self.epoch % self.interval == 0:
            checkpoint_path = path.join(self.dirpath, 'checkpoints',
                                        f'model_{self.epoch}.pt')
            self.save_checkpoint(model_state_dict, opt_state_dict,
                                 checkpoint_path)

    @classmethod
    def get_default_params(cls) -> dict:
        """Class method to get default params for model_config."""
        return dict(dirpath='.', interval=5)