Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Initial attempt at implementing resnet50 for use with CIFAR data. #7

Merged
merged 2 commits into from
Oct 11, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
5 changes: 4 additions & 1 deletion .gitignore
Original file line number Diff line number Diff line change
Expand Up @@ -150,4 +150,7 @@ _html/
.initialize_new_project.sh

# Model files
**/*.pth
**/*.pth

# Run results
results/
1 change: 1 addition & 0 deletions .pre-commit-config.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@ repos:
name: Clear output from Jupyter notebooks
description: Clear output from Jupyter notebooks.
files: \.ipynb$
exclude: ^docs/pre_executed
stages: [commit]
language: system
entry: jupyter nbconvert --clear-output
Expand Down
1,701 changes: 1,701 additions & 0 deletions docs/pre_executed/CNN_filter.ipynb

Large diffs are not rendered by default.

18 changes: 10 additions & 8 deletions example_config.toml
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@ log_level = "info" # Emit informational messages, warnings and all errors
# log_level = "debug" # Very verbose, emit all log messages.

data_dir = "/home/drew/code/fibad/data/"
results_dir = "./results" # Results get named <verb>-<timestamp> under this directory

[download]
sw = "22asec"
Expand Down Expand Up @@ -52,20 +53,20 @@ mask = false
[model]
# The name of the built-in model to use or the libpath to an external model
# e.g. "user_package.submodule.ExternalModel" or "ExampleAutoencoder"
name = "kbmod_ml.models.cnn.CNN"
name = "kbmod_ml.models.resnet50.RESNET50"

weights_filepath = "example_model.pth"
weights_filepath = "resnet50.pth"
epochs = 10

base_channel_size = 32
latent_dim =64
num_classes = 10

[data_loader]

[data_set]
# Name of the built-in data loader to use or the libpath to an external data loader
# e.g. "user_package.submodule.ExternalDataLoader" or "HSCDataLoader"
name = "CifarDataLoader"


[data_loader]
# Pixel dimensions used to crop all images prior to loading. Will prune any images that are too small.
#
# If not provided by user, the default of 'false' scans the directory for the smallest dimensioned files, and
Expand All @@ -83,9 +84,10 @@ crop_to = false
filters = false

# Default PyTorch DataLoader parameters
batch_size = 4
batch_size = 10
shuffle = true
num_workers = 2
num_workers = 10

[predict]
model_weights_file = false
batch_size = 32
65 changes: 65 additions & 0 deletions src/kbmod_ml/models/resnet50.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,65 @@
# ruff: noqa: D101, D102

# This example model is taken from the PyTorch CIFAR10 tutorial:
# https://pytorch.org/tutorials/beginner/blitz/cifar10_tutorial.html#define-a-convolutional-neural-network
import logging

import torch
import torch.nn as nn
import torch.nn.functional as F # noqa N812
import torch.optim as optim
from fibad.models.model_registry import fibad_model
from torchvision.models import resnet50

logger = logging.getLogger(__name__)


@fibad_model
class RESNET50(nn.Module):
def __init__(self, model_config, shape):
logger.info("This is an external model, not in FIBAD!!!")
super().__init__()

self.config = model_config

self.model = resnet50(pretrained=False, num_classes=self.config["model"]["num_classes"])

# Optimizer and criterion could be set directly, i.e. `self.optimizer = optim.SGD(...)`
# but we define them as methods as a way to allow for more flexibility in the future.
self.optimizer = self._optimizer()
self.criterion = self._criterion()

def forward(self, x):
return self.model(x)

def train_step(self, batch):
"""This function contains the logic for a single training step. i.e. the
contents of the inner loop of a ML training process.

Parameters
----------
batch : tuple
A tuple containing the inputs and labels for the current batch.

Returns
-------
Current loss value
The loss value for the current batch.
"""
inputs, labels = batch

self.optimizer.zero_grad()
outputs = self(inputs)
loss = self.criterion(outputs, labels)
loss.backward()
self.optimizer.step()
return {"loss": loss.item()}

def _criterion(self):
return nn.CrossEntropyLoss()

def _optimizer(self):
return optim.SGD(self.parameters(), lr=0.001, momentum=0.9)

def save(self):
torch.save(self.state_dict(), self.config.get("weights_filepath"))
Loading