Source code for scalr.nn.callbacks.tensorboard_logger

"""This file is an implementation of Tensorboard logging callback."""

import os
from os import path

import torch
from torch.utils.tensorboard import SummaryWriter

from scalr.nn.callbacks import CallbackBase


class TensorboardLogger(CallbackBase):
    """Callback that records per-epoch training metrics to Tensorboard.

    Every invocation of the callback counts as one finished epoch and writes
    the train/validation loss and accuracy as two grouped scalar charts.

    Attributes:
        epoch: An integer count of the epochs logged so far.
        writer: `SummaryWriter` object that writes the event files.
    """

    def __init__(self, dirpath: str = '.'):
        """Create the summary writer and reset the epoch counter.

        Args:
            dirpath: Path of the directory to store the experiment logs. The
                event files are written to the `logs` subdirectory of it.
        """
        log_dir = path.join(dirpath, 'logs')
        self.writer = SummaryWriter(log_dir)
        self.epoch = 0

    def __call__(self, train_loss: float, train_acc: float, val_loss: float,
                 val_acc: float, **kwargs):
        """Log train/val loss and accuracy for the epoch that just finished.

        Args:
            train_loss: Training loss of the epoch.
            train_acc: Training accuracy of the epoch.
            val_loss: Validation loss of the epoch.
            val_acc: Validation accuracy of the epoch.
            **kwargs: Additional keyword arguments; accepted but not used.
        """
        # The counter is advanced before writing, so the first logged step is 1.
        self.epoch += 1

        # One chart per metric, each holding a 'train' and a 'val' curve.
        curves = (
            ('Loss', train_loss, val_loss),
            ('Accuracy', train_acc, val_acc),
        )
        for chart_name, train_value, val_value in curves:
            self.writer.add_scalars(
                chart_name,
                {
                    'train': train_value,
                    'val': val_value
                },
                self.epoch,
            )

    @classmethod
    def get_default_params(cls):
        """Class method to get the default constructor params for model_config."""
        return {'dirpath': '.'}