Note: if you are looking for something more extensive, check out PyTorch Lightning. This library is designed mostly for my personal use.
Requirements:
- torch
- tqdm

Install with:

    pip install torch-runner
Features:
- Seeding of all sources of randomness (for reproducibility)
- Text logger
- Early stopping
- Saving of hyperparameters
- Weights & Biases support
Check out the `examples` folder, which contains a Jupyter notebook that trains a ResNet-50 using torch_runner.
import torch
import torch_runner as T
class myTrainer(T.TrainerModule):
    """Example trainer built by subclassing ``T.TrainerModule``.

    Each method below is a hook you override with your own logic. The bodies
    here are placeholders (``...``), so this class is a template rather than a
    working trainer.
    """

    def calc_metric(self, preds, target):
        """Calculate metrics such as accuracy from ``preds`` and ``target``."""
        ...  # replace with your metric computation

    def loss_fct(self, preds, target):
        """Calculate the loss between ``preds`` and ``target``."""
        ...  # replace with your loss computation

    def train_one_step(self, batch, batch_id):
        """Get the batch data from the dataloader and perform one update.

        ``batch`` is one item yielded by the training dataloader and
        ``batch_id`` is its index. The expected return value is defined by
        torch_runner (check its documentation).
        """
        ...  # replace with your forward/backward/optimizer step

    def valid_one_step(self, batch, batch_id):
        """Perform one validation step on ``batch`` (index ``batch_id``)."""
        ...  # replace with your validation logic
# Build the trainer configuration using the library's no-argument constructor.
config = T.TrainerConfig()

# Placeholder: replace `myModel` with your own torch.nn.Module subclass.
model = myModel()
optimizer = torch.optim.SGD(model.parameters(), lr=0.01)

# Placeholders: replace `...` with torch.utils.data.DataLoader instances.
train_dataloader = ...  # DataLoader over the training set
val_dataloader = ...  # DataLoader over the validation set

# Lower-case instance name so it is not confused with a class.
trainer = myTrainer(model, optimizer, config)
trainer.fit(train_dataloader, val_dataloader, epochs=10)