Source code for scalr.nn.trainer.simple_model_trainer

"""This file is a wrapper for Model trainer base class."""

import torch
from torch import nn
from torch.nn import Module
from torch.optim import Optimizer
from torch.utils.data import DataLoader

from scalr.nn.callbacks import CallbackExecutor
from scalr.nn.trainer import TrainerBase


[docs] class SimpleModelTrainer(TrainerBase): """Class for Simple model trainer. It works with dataloaders which contain all input tensors in line with model input, and the last tensor as target to train the model. """ def __init__(self, *args, **kwargs): super().__init__(*args, **kwargs)