Skip to content

Commit

Permalink
Merge pull request #15 from LucianoDeben/feature/DDR
Browse files Browse the repository at this point in the history
Add complete DRR functionality with example notebooks
  • Loading branch information
LucianoDeben authored May 2, 2024
2 parents 5822c62 + 81d6b43 commit f0ad338
Show file tree
Hide file tree
Showing 16 changed files with 1,414 additions and 641 deletions.
1 change: 1 addition & 0 deletions DiffDRR
Submodule DiffDRR added at 8d3248
Binary file removed Z-values_train.png
Binary file not shown.
4 changes: 0 additions & 4 deletions config.py

This file was deleted.

996 changes: 996 additions & 0 deletions notebooks/diffdrr.ipynb

Large diffs are not rendered by default.

504 changes: 0 additions & 504 deletions notebooks/drr.ipynb

This file was deleted.

221 changes: 98 additions & 123 deletions notebooks/preprocessing.ipynb

Large diffs are not rendered by default.

2 changes: 1 addition & 1 deletion notebooks/segmentation.ipynb
Original file line number Diff line number Diff line change
Expand Up @@ -514,7 +514,7 @@
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.12.2"
"version": "3.11.8"
}
},
"nbformat": 4,
Expand Down
Empty file added src/__init__.py
Empty file.
9 changes: 9 additions & 0 deletions src/config.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,9 @@
class Config:
learning_rate = 0.001
batch_size = 4
num_epochs = 10
val_frac = 0.2
seed = 42


config = Config()
22 changes: 21 additions & 1 deletion src/custom_transforms.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,10 @@
import os
from typing import Hashable, Optional

import numpy as np
import torch
from monai.transforms import MapTransform
from monai.config import KeysCollection
from monai.transforms import MapTransform, SaveImaged


class UndoOneHotEncoding(MapTransform):
Expand Down Expand Up @@ -50,6 +55,21 @@ def __call__(self, data):
return data


class ConvertToSingleChannel(MapTransform):
def __init__(self, keys):
super().__init__(keys)

def __call__(self, data):
for key in self.keys:
# Stack along the channel dimension
data[key] = np.argmax(data[key], axis=0).astype(np.int32)

# Unsqueeze the channel dimension
data[key] = torch.tensor(data[key]).unsqueeze(0)

return data


class RemoveNecrosisChannel(MapTransform):
"""
Remove the necrosis channel from the segmentation
Expand Down
73 changes: 73 additions & 0 deletions src/drr.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,73 @@
# Import libraries
import sys

sys.path.append("..")

import torch

from diffdrr.data import read
from diffdrr.drr import DRR
from diffdrr.pose import convert


def create_drr(
volume,
segmentation,
bone_attenuation_multiplier=5.0,
sdd=1020,
height=200,
width=200,
delx=2.0,
dely=2.0,
x0=0,
y0=0,
p_subsample=None,
reshape=True,
reverse_x_axis=True,
patch_size=None,
renderer="siddon",
rotations=torch.tensor([[0.0, 0.0, 0.0]]),
rotations_degrees=True,
translations=torch.tensor([[0.0, 850.0, 0.0]]),
mask_to_channels=True,
device="cpu",
):

# Read the image and segmentation subject
subject = read(
tensor=volume,
label_tensor=segmentation,
orientation="AP",
bone_attenuation_multiplier=bone_attenuation_multiplier,
)

# Create a DRR object
drr = DRR(
subject, # A torchio.Subject object storing the CT volume, origin, and voxel spacing
sdd=sdd, # Source-to-detector distance (i.e., the C-arm's focal length)
height=height, # Height of the DRR (if width is not seperately provided, the generated image is square)
width=width, # Width of the DRR
delx=delx, # Pixel spacing (in mm)
dely=dely, # Pixel spacing (in mm)
x0=x0, # # Principal point X-offset
y0=y0, # Principal point Y-offset
p_subsample=p_subsample, # Proportion of pixels to randomly subsample
reshape=reshape, # Return DRR with shape (b, 1, h, w)
reverse_x_axis=reverse_x_axis, # If True, obey radiologic convention (e.g., heart on right)
patch_size=patch_size, # Render patches of the DRR in series
renderer=renderer, # Rendering backend, either "siddon" or "trilinear"
).to(device)

# Ensure rotations are in radians
if rotations_degrees:
rotations = torch.deg2rad(rotations)

zero = torch.tensor([[0.0, 0.0, 0.0]], device=device)
pose1 = convert(
zero, translations, parameterization="euler_angles", convention="ZXY"
)
pose2 = convert(rotations, zero, parameterization="euler_angles", convention="ZXY")
pose = pose1.compose(pose2)

img = drr(pose, mask_to_channels=mask_to_channels)
return img
85 changes: 85 additions & 0 deletions src/main.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,85 @@
import sys

import torch
import torch.nn as nn
from monai.networks.nets import UNet

from config import config
from preprocessing import get_dataloaders, get_datasets, get_transforms
from train import train, validate

sys.path.append("..")


# Define main function
def main():

# Get the transforms
transform = get_transforms(contrast_value=1000)

# Get the datasets
train_dataset, val_dataset = get_datasets(
root_dir="../data",
collection="HCC-TACE-Seg",
transform=transform,
download=True,
download_len=5,
seg_type="SEG",
val_frac=config.val_frac,
seed=config.seed,
)

# Get the dataloaders
train_loader, val_loader = get_dataloaders(
train_dataset, val_dataset, batch_size=1, num_workers=0
)

# Set the device
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# Initialize the model
model = UNet(
in_channels=1,
out_channels=20,
spatial_dims=(512, 512, 96),
channels=5,
strides=1,
).to(device)

# Initialize the criterion
criterion = nn.CrossEntropyLoss()

# Initialize the optimizer
optimizer = torch.optim.Adam(model.parameters(), lr=config.learning_rate)

# Initialize the best validation loss
best_val_loss = float("inf")

# Train the model
for epoch in range(config.num_epochs):
train_loss, train_acc, train_dice, train_iou = train(
train_loader, model, criterion, optimizer, device
)
val_loss, val_acc, val_dice, val_iou = validate(
val_loader, model, criterion, device
)

print(
f"Epoch {epoch+1}/{config.num_epochs} - "
f"Train Loss: {train_loss:.4f}, "
f"Train Acc: {train_acc:.4f}, "
f"Train Dice: {train_dice:.4f}, "
f"Train IoU: {train_iou:.4f}, "
f"Val Loss: {val_loss:.4f}, "
f"Val Acc: {val_acc:.4f}, "
f"Val Dice: {val_dice:.4f}, "
f"Val IoU: {val_iou:.4f}"
)

if val_loss < best_val_loss:
best_val_loss = val_loss
torch.save(model.state_dict(), "best_model.pth")


if __name__ == "__main__":
main()
Empty file added src/model.py
Empty file.
20 changes: 12 additions & 8 deletions src/preprocessing.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@
DataStatsd,
EnsureChannelFirstd,
LoadImaged,
RandGridDistortiond,
Resized,
ResizeWithPadOrCropd,
ScaleIntensityRanged,
Expand All @@ -13,13 +14,14 @@
from src.custom_transforms import (
AddBackgroundChannel,
AddVesselContrast,
ConvertToSingleChannel,
IsolateArteries,
RemoveDualImage,
RemoveNecrosisChannel,
)


def get_transforms(resize_shape=[512, 512, 96]):
def get_transforms(resize_shape=[512, 512, 96], contrast_value=100):
"""
Create a composed transform for the data preprocessing of mask and image data
Expand All @@ -34,9 +36,9 @@ def get_transforms(resize_shape=[512, 512, 96]):
[
LoadImaged(reader="PydicomReader", keys=["image", "seg"]),
EnsureChannelFirstd(keys=["image", "seg"]),
ResizeWithPadOrCropd(keys=["image", "seg"], spatial_size=[512, 512, 64]),
# ResizeWithPadOrCropd(keys=["image", "seg"], spatial_size=[512, 512, 64]),
RemoveNecrosisChannel(keys=["seg"]),
Resized(keys=["image", "seg"], spatial_size=resize_shape),
# Resized(keys=["image", "seg"], spatial_size=resize_shape),
AddBackgroundChannel(keys=["seg"]),
# ScaleIntensityRanged(
# keys=["image"],
Expand All @@ -46,11 +48,13 @@ def get_transforms(resize_shape=[512, 512, 96]):
# b_max=230,
# clip=True,
# ),
AddVesselContrast(keys=["image", "seg"], contrast_value=200),
RemoveDualImage(keys=["image", "seg"]),
Resized(keys=["image", "seg"], spatial_size=resize_shape),
IsolateArteries(keys=["seg"]),
DataStatsd(keys=["image", "seg"], data_shape=True),
AddVesselContrast(keys=["image", "seg"], contrast_value=contrast_value),
# RemoveDualImage(keys=["image", "seg"]),
# Resized(keys=["image", "seg"], spatial_size=resize_shape),
# IsolateArteries(keys=["seg"]),
ConvertToSingleChannel(keys=["seg"]),
# DataStatsd(keys=["image", "seg"], data_shape=True),
# RandGridDistortiond(keys=["image", "seg"], prob=0.5),
],
lazy=False,
)
Expand Down
93 changes: 93 additions & 0 deletions src/train.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,93 @@
import sys

import numpy as np
import torch
import torch.nn as nn
import tqdm
from monai.networks.nets import UNet
from torchmetrics.functional.classification import (
accuracy,
dice,
multiclass_jaccard_index,
)

from preprocessing import get_dataloaders, get_datasets, get_transforms

sys.path.append("..")

from config import config


# Define the training loop
def train(train_dataloader, model, criterion, optimizer, device):
model.train()
train_loss = 0.0
train_acc = 0.0
train_dice = 0.0
train_iou = 0.0

for inputs, targets in tqdm(train_dataloader):
targets = targets.long().squeeze(dim=1)
inputs, targets = inputs.to(device), targets.to(device)

print(inputs.shape)

optimizer.zero_grad()
outputs = model(inputs)
loss = criterion(outputs, targets)
loss.backward()
optimizer.step()

train_loss += loss.detach()
outputs_max = torch.argmax(outputs, dim=1)

train_acc += accuracy(
outputs_max, targets, task="multiclass", num_classes=20, ignore_index=19
).detach()
train_dice += dice(outputs, targets, ignore_index=19).detach()
train_iou += multiclass_jaccard_index(
outputs, targets, num_classes=20, ignore_index=19
).detach()

num_batches = len(train_dataloader)
return (
(train_loss / num_batches).item(),
(train_acc / num_batches).item(),
(train_dice / num_batches).item(),
(train_iou / num_batches).item(),
)


# Define the validation loop
def validate(val_loader, model, criterion, device):
model.eval()
val_loss = 0.0
val_acc = 0.0
val_dice = 0.0
val_iou = 0.0

for inputs, targets in tqdm(val_loader):
targets = targets.long().squeeze(dim=1)
inputs, targets = inputs.to(device), targets.to(device)

outputs = model(inputs)
loss = criterion(outputs, targets)

val_loss += loss.detach()
outputs_max = torch.argmax(outputs, dim=1)

val_acc += accuracy(
outputs_max, targets, task="multiclass", num_classes=20, ignore_index=19
).detach()
val_dice += dice(outputs, targets, ignore_index=19).detach()
val_iou += multiclass_jaccard_index(
outputs, targets, num_classes=20, ignore_index=19
).detach()

num_batches = len(val_loader)
return (
(val_loss / num_batches).item(),
(val_acc / num_batches).item(),
(val_dice / num_batches).item(),
(val_iou / num_batches).item(),
)
25 changes: 25 additions & 0 deletions test/test_preprocessing.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,25 @@
import unittest

from monai.transforms import Compose

from src.preprocessing import get_transforms


class TestGetTransforms(unittest.TestCase):
def setUp(self):
self.transform = get_transforms()

def test_return_type(self):
self.assertIsInstance(self.transform, Compose)

def test_transform_length(self):
self.assertTrue(len(self.transform.transforms) > 0)

def test_transform_keys(self):
expected_keys = ["image", "seg"]
for transform in self.transform.transforms:
self.assertEqual(transform.keys, expected_keys)


if __name__ == "__main__":
unittest.main()

0 comments on commit f0ad338

Please sign in to comment.