trainer
_trainer
This module contains the base class for model trainers.
- class scalr.nn.trainer._trainer.TrainerBase(model: Module, opt: Optimizer, loss_fn: Module, callbacks: CallbackExecutor, device: str = 'cpu')
Bases: object
Base class for a model trainer, which trains and validates a model.
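As a rough illustration of what "trains and validates a model" involves, the sketch below shows one training epoch and one validation pass in plain PyTorch. It is not scalr's implementation: the helper names train_one_epoch and validate are hypothetical, and the sketch assumes every batch is an (input, target) pair.

```python
import torch
from torch import nn
from torch.utils.data import DataLoader, TensorDataset


def train_one_epoch(model: nn.Module, opt: torch.optim.Optimizer,
                    loss_fn: nn.Module, dl: DataLoader, device: str = "cpu") -> float:
    """Hypothetical helper: one pass over dl that updates weights; returns mean batch loss."""
    model.train()  # training-mode behaviour (dropout active, batch-norm stats updated)
    total, n_batches = 0.0, 0
    for x, y in dl:
        x, y = x.to(device), y.to(device)
        opt.zero_grad()
        loss = loss_fn(model(x), y)
        loss.backward()
        opt.step()
        total += loss.item()
        n_batches += 1
    return total / n_batches


@torch.no_grad()  # validation needs no gradients
def validate(model: nn.Module, loss_fn: nn.Module, dl: DataLoader,
             device: str = "cpu") -> float:
    """Hypothetical helper: evaluate without updating weights; returns mean batch loss."""
    model.eval()
    total, n_batches = 0.0, 0
    for x, y in dl:
        x, y = x.to(device), y.to(device)
        total += loss_fn(model(x), y).item()
        n_batches += 1
    return total / n_batches


# Usage: fit a linear target and compare validation loss before and after training.
torch.manual_seed(0)
x = torch.randn(64, 3)
y = x @ torch.tensor([[1.0], [-2.0], [0.5]])
dl = DataLoader(TensorDataset(x, y), batch_size=16)
model = nn.Linear(3, 1)
opt = torch.optim.SGD(model.parameters(), lr=0.1)
loss_fn = nn.MSELoss()

before = validate(model, loss_fn, dl)
for _ in range(20):
    train_one_epoch(model, opt, loss_fn, dl)
after = validate(model, loss_fn, dl)
```

Keeping the training and validation passes separate makes the mode switch (model.train() versus model.eval() under no_grad) explicit, which is the part most easily forgotten in hand-written loops.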
- train(epochs: int, train_dl: DataLoader, val_dl: DataLoader)
Trains the model, validates it on val_dl, and executes the callbacks.
- Parameters:
epochs – Maximum number of epochs to train the model for.
train_dl – Training dataloader.
val_dl – Validation dataloader.
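A minimal sketch of an epoch loop that matches these parameters: it trains for at most epochs epochs, validates after each one, and lets a callback stop training early. The real CallbackExecutor interface is not described here, so the callback signature (epoch, train_loss, val_loss) returning True to stop, and the names run_epoch, train and make_early_stopping, are all hypothetical stand-ins.

```python
from typing import Callable, List, Optional

import torch
from torch import nn
from torch.utils.data import DataLoader, TensorDataset

# Hypothetical stand-in for scalr's CallbackExecutor:
# called as callback(epoch, train_loss, val_loss); returns True to stop training.
Callback = Callable[[int, float, float], bool]


def run_epoch(model: nn.Module, loss_fn: nn.Module, dl: DataLoader,
              opt: Optional[torch.optim.Optimizer] = None, device: str = "cpu") -> float:
    """One pass over dl; weights are updated only when an optimizer is given."""
    training = opt is not None
    model.train(training)  # train mode when training, eval mode when validating
    total, n_batches = 0.0, 0
    with torch.set_grad_enabled(training):
        for x, y in dl:
            x, y = x.to(device), y.to(device)
            loss = loss_fn(model(x), y)
            if training:
                opt.zero_grad()
                loss.backward()
                opt.step()
            total += loss.item()
            n_batches += 1
    return total / n_batches


def train(model: nn.Module, opt: torch.optim.Optimizer, loss_fn: nn.Module,
          callback: Callback, epochs: int, train_dl: DataLoader,
          val_dl: DataLoader, device: str = "cpu") -> List[float]:
    """Train for at most `epochs` epochs; return the validation loss of each epoch."""
    model.to(device)
    history: List[float] = []
    for epoch in range(epochs):
        train_loss = run_epoch(model, loss_fn, train_dl, opt, device)
        val_loss = run_epoch(model, loss_fn, val_dl, None, device)
        history.append(val_loss)
        if callback(epoch, train_loss, val_loss):  # e.g. early stopping
            break
    return history


def make_early_stopping(patience: int) -> Callback:
    """Hypothetical callback: stop after `patience` epochs without improvement."""
    best, bad_epochs = float("inf"), 0

    def callback(epoch: int, train_loss: float, val_loss: float) -> bool:
        nonlocal best, bad_epochs
        if val_loss < best:
            best, bad_epochs = val_loss, 0
        else:
            bad_epochs += 1
        return bad_epochs >= patience

    return callback


# Usage: fit y = x1 + x2 with a linear model and early stopping.
torch.manual_seed(0)
x = torch.randn(80, 2)
y = x.sum(dim=1, keepdim=True)
train_dl = DataLoader(TensorDataset(x[:64], y[:64]), batch_size=16, shuffle=True)
val_dl = DataLoader(TensorDataset(x[64:], y[64:]), batch_size=16)
model = nn.Linear(2, 1)
opt = torch.optim.SGD(model.parameters(), lr=0.1)
history = train(model, opt, nn.MSELoss(), make_early_stopping(patience=3),
                epochs=50, train_dl=train_dl, val_dl=val_dl)
```

Here epochs acts as an upper bound: the loop ends either after that many epochs or as soon as the callback returns True, so len(history) is the number of epochs actually run.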
simple_model_trainer
This module contains a simple trainer built on the model trainer base class, TrainerBase.
- class scalr.nn.trainer.simple_model_trainer.SimpleModelTrainer(*args, **kwargs)
Bases: TrainerBase
Class for a simple model trainer.
It works with dataloaders whose batches contain all of the model's input tensors, in the order the model expects them, followed by the target tensor as the last element.
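The batch convention can be sketched as follows: every tensor except the last is passed to the model's forward method, and the last tensor is the target. TwoInputModel and simple_step are hypothetical names; this illustrates the convention only and is not SimpleModelTrainer's actual code.

```python
import torch
from torch import nn
from torch.utils.data import DataLoader, TensorDataset


class TwoInputModel(nn.Module):
    """Hypothetical model whose forward() takes two input tensors."""

    def __init__(self) -> None:
        super().__init__()
        self.a = nn.Linear(3, 1)
        self.b = nn.Linear(2, 1)

    def forward(self, x1: torch.Tensor, x2: torch.Tensor) -> torch.Tensor:
        return self.a(x1) + self.b(x2)


def simple_step(model: nn.Module, opt: torch.optim.Optimizer, loss_fn: nn.Module,
                batch, device: str = "cpu") -> float:
    """One optimisation step on a batch laid out as (*model_inputs, target)."""
    *inputs, target = [t.to(device) for t in batch]  # last tensor is the target
    opt.zero_grad()
    loss = loss_fn(model(*inputs), target)  # inputs are passed positionally, in order
    loss.backward()
    opt.step()
    return loss.item()


# Usage: the dataset yields (x1, x2, y), matching forward(x1, x2) plus a target.
torch.manual_seed(0)
x1, x2 = torch.randn(32, 3), torch.randn(32, 2)
y = x1.sum(dim=1, keepdim=True) - x2.sum(dim=1, keepdim=True)
dl = DataLoader(TensorDataset(x1, x2, y), batch_size=8)  # 4 batches per epoch
model = TwoInputModel()
opt = torch.optim.SGD(model.parameters(), lr=0.05)
loss_fn = nn.MSELoss()
losses = [simple_step(model, opt, loss_fn, batch) for _ in range(10) for batch in dl]
```

Unpacking with *inputs, target lets one step function serve models with any number of inputs, provided the dataset's tensor order matches the model's forward signature.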