# Source code for scalr.nn.callbacks.model_checkpoint

"""This file is an implementation of model checkpoint callback."""

import os
from os import path

import torch

from scalr.nn.callbacks import CallbackBase


class ModelCheckpoint(CallbackBase):
    """Model checkpointing to save model weights at regular intervals.

    Each call to the callback counts as one epoch. Whenever the running
    epoch count is a multiple of `interval`, the model and optimizer state
    dicts are written to `<dirpath>/checkpoints/model_<epoch>.pt`.

    Attributes:
        epoch: An integer count of how many times the callback has been
            called (the caller is expected to call it once per epoch).
        interval: Regular interval of model checkpointing. A value of 0
            disables checkpointing.
        dirpath: Directory under which the `checkpoints` folder lives.
    """

    def __init__(self, dirpath: str, interval: int = 5):
        """Initialize required parameters for model checkpoint callback.

        Args:
            dirpath: To store the respective model checkpoints.
            interval: Regular interval of model checkpointing. It is
                converted with `int()`; 0 disables checkpointing.
        """
        self.epoch = 0
        self.interval = int(interval)
        self.dirpath = dirpath

        # The checkpoint directory is only created when checkpointing is
        # enabled, so a disabled callback leaves the filesystem untouched.
        if self.interval:
            os.makedirs(path.join(dirpath, 'checkpoints'), exist_ok=True)

    def save_checkpoint(self, model_state_dict: dict, opt_state_dict: dict,
                        path: str) -> None:
        """A function to save model & optimizer state dict to the given path.

        The checkpoint is a dict with the keys `epoch`, `model_state_dict`
        and `optimizer_state_dict`, serialized with `torch.save`.

        Args:
            model_state_dict: Model's state dict.
            opt_state_dict: Optimizer's state dict.
            path: Path to store checkpoint to.
        """
        # NOTE: the `path` parameter shadows the module-level `os.path`
        # import inside this method; it is kept for interface compatibility.
        checkpoint = {
            'epoch': self.epoch,
            'model_state_dict': model_state_dict,
            'optimizer_state_dict': opt_state_dict
        }
        torch.save(checkpoint, path)

    def __call__(self, model_state_dict: dict, opt_state_dict: dict,
                 **kwargs) -> None:
        """A function that evaluates when to save a checkpoint.

        Increments the epoch counter on every call and saves a checkpoint
        when checkpointing is enabled and the counter is a multiple of
        `interval`.

        Args:
            model_state_dict: Model's state dict.
            opt_state_dict: Optimizer's state dict.
            **kwargs: Additional keyword arguments; accepted and ignored.
        """
        self.epoch += 1

        # `self.interval` is checked first so that interval == 0 never
        # reaches the modulo (which would raise ZeroDivisionError).
        if self.interval and self.epoch % self.interval == 0:
            checkpoint_path = path.join(self.dirpath, 'checkpoints',
                                        f'model_{self.epoch}.pt')
            self.save_checkpoint(model_state_dict, opt_state_dict,
                                 checkpoint_path)

    @classmethod
    def get_default_params(cls) -> dict:
        """Class method to get default params for model_config.

        Returns:
            A dict with the default `dirpath` ('.') and `interval` (5).
        """
        return dict(dirpath='.', interval=5)