Skip to content

theSoenke/pytorch-trainer

Folders and files

Name
Last commit message
Last commit date

Latest commit

 

History

47 Commits
 
 
 
 
 
 
 
 
 
 
 
 

Repository files navigation

PyTorch Trainer

A lightweight wrapper around PyTorch that removes boilerplate code so you can focus on the important parts.

Example

import os

import torch
import torchvision.transforms as transforms
from module import Module
from torch.nn import functional as F
from torch.utils.data import DataLoader
from torchvision.datasets import MNIST

from pytorch_trainer import EarlyStopping, ModelCheckpoint, Module, Trainer

class MNISTModel(Module):
    """Single linear-layer classifier for MNIST digits.

    Provides the hooks used by the trainer in this example:
    ``training_step``, ``validation_step``, ``validation_end``,
    ``configure_optimizers``, ``train_dataloader`` and ``val_dataloader``.
    """

    def __init__(self):
        super().__init__()
        # 28 * 28 = 784 flattened pixel features -> 10 class scores.
        self.l1 = torch.nn.Linear(28 * 28, 10)

    def forward(self, x):
        """Return unnormalized class scores (logits) for a batch.

        Each sample is flattened to a vector before the linear layer.
        No activation is applied on purpose: ``F.cross_entropy`` expects raw
        logits, and a ReLU here would clamp every negative score to 0 (with
        zero gradient), so the model could not push unlikely classes down.
        """
        return self.l1(x.view(x.size(0), -1))

    def training_step(self, batch, batch_num):
        """Compute the cross-entropy loss for one training batch.

        Args:
            batch: ``(inputs, targets)`` pair as yielded by the train loader.
            batch_num: index of the batch (unused here).

        Returns:
            ``{'loss': loss}`` -- presumably the key the trainer
            backpropagates; confirm against ``Trainer``.
        """
        x, y = batch
        logits = self.forward(x)
        loss = F.cross_entropy(logits, y)
        return {'loss': loss}

    def validation_step(self, batch, batch_num):
        """Compute the cross-entropy loss for one validation batch.

        Returns:
            ``{'val_loss': loss}`` for later aggregation in ``validation_end``.
        """
        x, y = batch
        logits = self.forward(x)
        return {'val_loss': F.cross_entropy(logits, y)}

    def validation_end(self, outputs):
        """Average the per-batch validation losses into one scalar.

        Note: this is a mean of per-batch means, so a smaller final batch
        counts as much as a full one.

        Returns:
            ``{'val_loss': avg_loss}``; the callbacks below monitor this key.
        """
        avg_loss = torch.stack([x['val_loss'] for x in outputs]).mean()
        return {'val_loss': avg_loss}

    def configure_optimizers(self):
        """Return an Adam optimizer over all model parameters (lr=0.02)."""
        return torch.optim.Adam(self.parameters(), lr=0.02)

    def train_dataloader(self):
        """Return a loader over the MNIST training split (downloaded to cwd)."""
        dataset = MNIST(os.getcwd(), train=True, download=True,
                        transform=transforms.ToTensor())
        # Reshuffle every epoch so batches are not always seen in the same
        # fixed order.
        return DataLoader(dataset, batch_size=32, shuffle=True)

    def val_dataloader(self):
        """Return a loader over the MNIST test split (downloaded to cwd)."""
        dataset = MNIST(os.getcwd(), train=False, download=True,
                        transform=transforms.ToTensor())
        # Order does not matter for evaluation, so no shuffling here.
        return DataLoader(dataset, batch_size=32)


# Both callbacks watch the same validation metric, where lower is better.
_monitored_metric = {'monitor': 'val_loss', 'mode': 'min'}

# Write checkpoints to ./checkpoints, keeping only the best one seen so far.
checkpoint_callback = ModelCheckpoint(
    directory='./checkpoints',
    save_best_only=True,
    **_monitored_metric,
)

# Stop after 5 evaluations without any improvement (min_delta of 0).
early_stop_callback = EarlyStopping(
    min_delta=0.00,
    patience=5,
    **_monitored_metric,
)

model = MNISTModel()
trainer = Trainer(
    early_stop_callback=early_stop_callback,
    checkpoint_callback=checkpoint_callback,
)
trainer.fit(model)

Inspired by PyTorch Lightning

About

Lightweight PyTorch trainer

Topics

Resources

License

Stars

Watchers

Forks

Releases

No releases published

Packages

No packages published

Languages