Trainer

_trainer

This file defines the base class for model trainers.

class scalr.nn.trainer._trainer.TrainerBase(model: Module, opt: Optimizer, loss_fn: Module, callbacks: CallbackExecutor, device: str = 'cpu')

Bases: object

Base class for model trainers. It trains and validates a model.

train(epochs: int, train_dl: DataLoader, val_dl: DataLoader)

This function trains the model and executes callbacks.

Parameters:
  • epochs – Maximum number of epochs to train the model for.

  • train_dl – Training dataloader.

  • val_dl – Validation dataloader.
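Because epochs is documented as a maximum, training can presumably end earlier, for example when a callback requests it. The control flow can be sketched in plain Python as follows. This is a minimal illustration under stated assumptions, not TrainerBase's actual code: the functions train_loop and make_early_stopper are hypothetical, and the real CallbackExecutor interface is not documented on this page, so a single "should stop" predicate stands in for it.

```python
from typing import Callable, List, Tuple

# Hypothetical stand-ins for TrainerBase.train_one_epoch / validation:
# each returns (loss, accuracy) for one pass over its dataloader.
EpochFn = Callable[[], Tuple[float, float]]


def train_loop(
    epochs: int,
    train_one_epoch: EpochFn,
    validation: EpochFn,
    should_stop: Callable[[float], bool],
) -> List[dict]:
    """Run at most `epochs` epochs; stop early when `should_stop` returns True."""
    history = []
    for epoch in range(epochs):
        train_loss, train_acc = train_one_epoch()
        val_loss, val_acc = validation()
        history.append({
            "epoch": epoch,
            "train_loss": train_loss,
            "train_acc": train_acc,
            "val_loss": val_loss,
            "val_acc": val_acc,
        })
        # Assumption: callbacks run once per epoch, after validation.
        # A single predicate on the validation loss stands in for them here.
        if should_stop(val_loss):
            break
    return history


def make_early_stopper(patience: int) -> Callable[[float], bool]:
    """Hypothetical early-stopping callback: stop after `patience` epochs without improvement."""
    best = float("inf")
    bad_epochs = 0

    def should_stop(val_loss: float) -> bool:
        nonlocal best, bad_epochs
        if val_loss < best:
            best = val_loss
            bad_epochs = 0
        else:
            bad_epochs += 1
        return bad_epochs >= patience

    return should_stop


# Usage with fake per-epoch metrics: validation loss stops improving after epoch 1.
val_losses = iter([1.0, 0.8, 0.9, 0.95, 0.7])
history = train_loop(
    epochs=5,
    train_one_epoch=lambda: (0.5, 0.9),
    validation=lambda: (next(val_losses), 0.8),
    should_stop=make_early_stopper(patience=2),
)
print(len(history))  # 4: training stops before the 5-epoch maximum
```

Passing the per-epoch steps in as functions keeps the sketch independent of PyTorch. It shows only the ordering of "train, validate, run callbacks, maybe stop".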

train_one_epoch(dl: DataLoader) → tuple[float, float]

This function trains the model for one epoch.

Parameters:

dl – Training dataloader.

Returns:

Tuple of (training loss, training accuracy).
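The following is a self-contained illustration of what "one epoch" involves: iterate over every batch, take an optimizer step per batch, and aggregate loss and accuracy over all samples. It is a sketch under assumptions, not scalr's implementation. A toy 1-D logistic model with hand-written gradients replaces the PyTorch model, optimizer, and loss_fn. The function name train_one_epoch here is reused only for readability, and averaging the loss per sample is an assumed convention.

```python
import math
from typing import List, Sequence, Tuple

Batch = Tuple[List[float], List[int]]  # (inputs, binary targets)


def sigmoid(z: float) -> float:
    return 1.0 / (1.0 + math.exp(-z))


def train_one_epoch(
    params: List[float], batches: Sequence[Batch], lr: float = 0.1
) -> Tuple[float, float]:
    """One epoch of mini-batch gradient descent on a toy 1-D logistic model.

    `params` is [weight, bias] and is updated in place.
    Returns (mean training loss, training accuracy) over all samples.
    """
    total_loss, correct, seen = 0.0, 0, 0
    for xs, ys in batches:
        n = len(xs)
        if n == 0:
            continue
        w, b = params
        grad_w = grad_b = 0.0
        for x, y in zip(xs, ys):
            p = sigmoid(w * x + b)
            # Binary cross-entropy, with p clamped to avoid log(0).
            p_c = min(max(p, 1e-12), 1 - 1e-12)
            total_loss += -(y * math.log(p_c) + (1 - y) * math.log(1 - p_c))
            correct += int((p >= 0.5) == bool(y))
            # d(BCE)/dz = p - y for a sigmoid output.
            grad_w += (p - y) * x
            grad_b += p - y
        seen += n
        # Optimizer step, the analogue of loss.backward(); opt.step() in PyTorch.
        params[0] -= lr * grad_w / n
        params[1] -= lr * grad_b / n
    return total_loss / seen, correct / seen


# Usage: two batches of a linearly separable toy dataset.
params = [0.0, 0.0]
batches = [([-2.0, -1.0], [0, 0]), ([1.0, 2.0], [1, 1])]
for epoch in range(5):
    loss, acc = train_one_epoch(params, batches)
    print(f"epoch {epoch}: loss={loss:.3f} acc={acc:.2f}")
```

Metrics are computed from each batch's forward pass before that batch's update. This matches the usual PyTorch training-loop pattern, in which the reported epoch loss mixes slightly different parameter states.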

validation(dl: DataLoader) → tuple[float, float]

This function evaluates the model on the validation data.

Parameters:

dl – Validation dataloader.

Returns:

Tuple of (validation loss, validation accuracy).
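Validation differs from training mainly in that it only runs forward passes and never updates the model. A typical PyTorch implementation would call model.eval() and wrap the loop in torch.no_grad(). The plain-Python sketch below makes that explicit with a toy 1-D logistic model. The function validation and its (weight, bias) parameters are hypothetical stand-ins, and it assumes that loss is averaged per sample.

```python
import math
from typing import List, Sequence, Tuple

Batch = Tuple[List[float], List[int]]  # (inputs, binary targets)


def validation(params: Sequence[float], batches: Sequence[Batch]) -> Tuple[float, float]:
    """Evaluate a toy 1-D logistic model without updating it.

    Returns (mean validation loss, validation accuracy) over all samples.
    """
    w, b = params  # read-only: validation must not change the model
    total_loss, correct, seen = 0.0, 0, 0
    for xs, ys in batches:
        for x, y in zip(xs, ys):
            p = 1.0 / (1.0 + math.exp(-(w * x + b)))
            p = min(max(p, 1e-12), 1 - 1e-12)  # avoid log(0)
            total_loss += -(y * math.log(p) + (1 - y) * math.log(1 - p))
            correct += int((p >= 0.5) == bool(y))
            seen += 1
    return total_loss / seen, correct / seen


# Usage: an untrained model (w = b = 0) predicts p = 0.5 everywhere.
loss, acc = validation([0.0, 0.0], [([1.0, -1.0], [1, 0])])
print(f"loss={loss:.4f} acc={acc:.2f}")  # loss=0.6931 acc=0.50
```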

simple_model_trainer

This file is a wrapper around the model trainer base class.

class scalr.nn.trainer.simple_model_trainer.SimpleModelTrainer(*args, **kwargs)

Bases: TrainerBase

Class for a simple model trainer.

It works with dataloaders whose batches contain all of the model's input tensors, in the order the model expects them, followed by the target tensor as the last element, which is used to train the model.
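The batch convention can be illustrated with a short sketch. A DataLoader built over, for example, torch.utils.data.TensorDataset(x, mask, y) yields batches of the form [x, mask, y]. Every element except the last is passed to the model positionally, and the last element is the target. The helpers split_batch and forward_batch and the two-input toy model are hypothetical and do not come from scalr. Plain lists stand in for tensors so the example runs without PyTorch.

```python
from typing import Any, Callable, Sequence, Tuple


def split_batch(batch: Sequence[Any]) -> Tuple[Tuple[Any, ...], Any]:
    """Split a batch into model inputs (all but the last item) and the target (last item)."""
    if len(batch) < 2:
        raise ValueError("batch must contain at least one input and a target")
    *inputs, target = batch
    return tuple(inputs), target


def forward_batch(model: Callable[..., Any], batch: Sequence[Any]) -> Tuple[Any, Any]:
    """Call the model with every input, in batch order; return (output, target)."""
    inputs, target = split_batch(batch)
    return model(*inputs), target


# Hypothetical two-input model: masks its first input element-wise.
def toy_model(x, mask):
    return [a * m for a, m in zip(x, mask)]


output, target = forward_batch(toy_model, ([1, 2, 3], [1, 0, 1], [0, 1, 0]))
print(output)  # [1, 0, 3]
print(target)  # [0, 1, 0]
```

Unpacking with `*inputs, target = batch` keeps the trainer agnostic to how many inputs the model takes, so the only contract is the dataset's tensor order.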