scalr.nn.callbacks package

Submodules

scalr.nn.callbacks._callbacks module

This module defines the base class for implementing callbacks.

class scalr.nn.callbacks._callbacks.CallbackBase(dirpath='.')

Bases: object

Base class to build callbacks.

classmethod get_default_params()

Class method to get default params for callbacks config.
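As a rough illustration of how a subclass might build on this base class, the sketch below uses a stand-in `CallbackBase` so that it runs without scalr installed. The `__call__` hook, the `EpochRecorder` subclass, the `val_loss` keyword, and the contents of the returned defaults dict are all assumptions made for illustration; they are not confirmed by this reference.

```python
class CallbackBase:
    """Stand-in for scalr's CallbackBase (assumed shape, for illustration only)."""

    def __init__(self, dirpath='.'):
        self.dirpath = dirpath

    def __call__(self, **kwargs):
        # Assumed hook: subclasses override this to act on the training state.
        pass

    @classmethod
    def get_default_params(cls) -> dict:
        # Defaults that a callbacks config could fall back on (assumed content).
        return dict(dirpath='.')


class EpochRecorder(CallbackBase):
    """Hypothetical callback that records the validation loss of each epoch."""

    def __init__(self, dirpath='.'):
        super().__init__(dirpath)
        self.history = []

    def __call__(self, **kwargs):
        self.history.append(kwargs.get('val_loss'))


# Build the callback from its defaults, then call it once per epoch.
recorder = EpochRecorder(**EpochRecorder.get_default_params())
recorder(val_loss=0.42)
recorder(val_loss=0.40)
```

Keeping the defaults in a class method means a config loader can fill in missing parameters without instantiating the callback first.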

class scalr.nn.callbacks._callbacks.CallbackExecutor(dirpath: str, callbacks: list[dict])

Bases: object

Wrapper class to execute all enabled callbacks.

Enabled callbacks are executed in turn, with the early stopping callback executed last so that it can return a flag indicating whether model training should continue or stop.

execute(**kwargs) → bool

Execute all the enabled callbacks. Returns early stopping condition.
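The ordering described above can be sketched as follows. This is a minimal, self-contained approximation and not scalr's implementation: `CallbackExecutorSketch`, its constructor arguments, the two toy callbacks, and the convention that `True` means "stop training" are assumptions.

```python
class CallbackExecutorSketch:
    """Illustrative executor: run the plain callbacks first, early stopping last."""

    def __init__(self, callbacks, early_stopping=None):
        self.callbacks = callbacks            # callbacks whose result is ignored
        self.early_stopping = early_stopping  # optional callable returning a bool

    def execute(self, **kwargs) -> bool:
        for callback in self.callbacks:
            callback(**kwargs)
        # Early stopping runs last so that its flag is what the caller receives.
        if self.early_stopping is not None:
            return self.early_stopping(**kwargs)
        return False  # no early stopping configured: never request a stop


calls = []


def log_epoch(**kwargs):
    # Toy callback that only records that it ran.
    calls.append(('log', kwargs['epoch']))


def stop_after_two(**kwargs):
    # Toy early-stopping rule: request a stop from epoch 2 onwards.
    calls.append(('stop', kwargs['epoch']))
    return kwargs['epoch'] >= 2


executor = CallbackExecutorSketch([log_epoch], early_stopping=stop_after_two)
flags = [executor.execute(epoch=e) for e in (1, 2)]
```

A training loop would typically call `execute` once per epoch and break out of the loop when the returned flag is true.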

scalr.nn.callbacks.early_stopping module

This module implements the early stopping callback.

class scalr.nn.callbacks.early_stopping.EarlyStopping(dirpath: str | None = None, patience: int = 3, min_delta: float = 0.0001)

Bases: CallbackBase

Implements early stopping based upon validation loss.

patience

Number of epochs with no improvement after which training will be stopped.

min_delta

Minimum change in the monitored quantity to qualify as an improvement, i.e. an absolute change of less than min_delta will count as no improvement.

classmethod get_default_params()

Class method to get default params for model_config.
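To make the interaction between `patience` and `min_delta` concrete, here is a minimal sketch of one common way to implement this rule. It is not taken from scalr's source: the class name, the `best_loss` and `epochs_without_improvement` attributes, and the choice to treat only a decrease in validation loss of at least `min_delta` as an improvement are assumptions.

```python
class EarlyStoppingSketch:
    """Illustrative early stopping on validation loss (not scalr's code)."""

    def __init__(self, patience: int = 3, min_delta: float = 1e-4):
        self.patience = patience
        self.min_delta = min_delta
        self.best_loss = float('inf')
        self.epochs_without_improvement = 0

    def __call__(self, val_loss: float) -> bool:
        # Only a drop of at least min_delta counts as an improvement.
        if self.best_loss - val_loss >= self.min_delta:
            self.best_loss = val_loss
            self.epochs_without_improvement = 0
        else:
            self.epochs_without_improvement += 1
        # Request a stop once `patience` epochs have passed without improvement.
        return self.epochs_without_improvement >= self.patience


stopper = EarlyStoppingSketch(patience=3, min_delta=1e-4)
losses = [0.50, 0.40, 0.40005, 0.39995, 0.39997]
flags = [stopper(loss) for loss in losses]
```

In this run the last three losses differ from the best value (0.40) by less than `min_delta`, so the counter reaches the patience of 3 on the fifth epoch and only the final flag is true.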

scalr.nn.callbacks.model_checkpoint module

This module implements the model checkpoint callback.

class scalr.nn.callbacks.model_checkpoint.ModelCheckpoint(dirpath: str, interval: int = 5)

Bases: CallbackBase

Model checkpointing to save model weights at regular intervals.

epoch

An integer count of epochs trained.

max_validation_acc

Keeps track of the maximum validation accuracy across all epochs.

interval

Regular interval of model checkpointing.

classmethod get_default_params()

Class method to get default params for model_config.

save_checkpoint(model_state_dict: dict, opt_state_dict: dict, path: str)

A function to save model & optimizer state dict to the given path.

Parameters:
  • model_state_dict – Model’s state dict.

  • opt_state_dict – Optimizer’s state dict.

  • path – Path to store checkpoint to.
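The interval-based saving described for this class can be sketched as below. This is a simplified stand-in, not the library's code: `should_checkpoint`, the 1-based epoch count, the dictionary keys, and the file name are hypothetical, and `pickle` is used only so the example runs without a deep learning framework (a PyTorch-based implementation would more likely serialize with `torch.save`).

```python
import os
import pickle
import tempfile


def should_checkpoint(epoch: int, interval: int = 5) -> bool:
    """Return True on every `interval`-th epoch (assumes a 1-based epoch count)."""
    return interval > 0 and epoch % interval == 0


def save_checkpoint(model_state_dict: dict, opt_state_dict: dict, path: str) -> None:
    """Bundle both state dicts into a single file (pickle as a stand-in)."""
    checkpoint = {
        'model_state_dict': model_state_dict,   # hypothetical key name
        'optimizer_state_dict': opt_state_dict,  # hypothetical key name
    }
    with open(path, 'wb') as f:
        pickle.dump(checkpoint, f)


# With interval=5, epochs 5 and 10 trigger a checkpoint in a 12-epoch run.
checkpoint_epochs = [e for e in range(1, 13) if should_checkpoint(e, interval=5)]

# Round-trip a toy checkpoint through a temporary directory.
with tempfile.TemporaryDirectory() as tmp:
    path = os.path.join(tmp, 'model_5.pt')  # hypothetical file name
    save_checkpoint({'w': [0.1, 0.2]}, {'lr': 0.001}, path)
    with open(path, 'rb') as f:
        restored = pickle.load(f)
```

Storing the optimizer state next to the model weights is what allows training to resume from the checkpoint rather than only running inference with it.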

scalr.nn.callbacks.tensorboard_logger module

This module implements the TensorBoard logging callback.

class scalr.nn.callbacks.tensorboard_logger.TensorboardLogger(dirpath: str = '.')

Bases: CallbackBase

Tensorboard logging of the training process.

epoch

An integer count of epochs trained.

writer

Object that writes to tensorboard.

classmethod get_default_params()

Class method to get default params for model_config.

scalr.nn.callbacks.test_early_stopping module

This is a test file for early_stopping.py.

scalr.nn.callbacks.test_early_stopping.test_early_stopping()

This function tests early stopping of the model.

Module contents