-
Notifications
You must be signed in to change notification settings - Fork 0
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Merge pull request #15 from LucianoDeben/feature/DDR
Add complete DRR functionality with example notebooks
- Loading branch information
Showing
16 changed files
with
1,414 additions
and
641 deletions.
There are no files selected for viewing
Submodule DiffDRR
added at
8d3248
Binary file not shown.
Large diffs are not rendered by default.
Oops, something went wrong.
This file was deleted.
Oops, something went wrong.
Large diffs are not rendered by default.
Oops, something went wrong.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Empty file.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,9 @@ | ||
class Config: | ||
learning_rate = 0.001 | ||
batch_size = 4 | ||
num_epochs = 10 | ||
val_frac = 0.2 | ||
seed = 42 | ||
|
||
|
||
config = Config() |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,73 @@ | ||
# Import libraries | ||
import sys | ||
|
||
sys.path.append("..") | ||
|
||
import torch | ||
|
||
from diffdrr.data import read | ||
from diffdrr.drr import DRR | ||
from diffdrr.pose import convert | ||
|
||
|
||
def create_drr( | ||
volume, | ||
segmentation, | ||
bone_attenuation_multiplier=5.0, | ||
sdd=1020, | ||
height=200, | ||
width=200, | ||
delx=2.0, | ||
dely=2.0, | ||
x0=0, | ||
y0=0, | ||
p_subsample=None, | ||
reshape=True, | ||
reverse_x_axis=True, | ||
patch_size=None, | ||
renderer="siddon", | ||
rotations=torch.tensor([[0.0, 0.0, 0.0]]), | ||
rotations_degrees=True, | ||
translations=torch.tensor([[0.0, 850.0, 0.0]]), | ||
mask_to_channels=True, | ||
device="cpu", | ||
): | ||
|
||
# Read the image and segmentation subject | ||
subject = read( | ||
tensor=volume, | ||
label_tensor=segmentation, | ||
orientation="AP", | ||
bone_attenuation_multiplier=bone_attenuation_multiplier, | ||
) | ||
|
||
# Create a DRR object | ||
drr = DRR( | ||
subject, # A torchio.Subject object storing the CT volume, origin, and voxel spacing | ||
sdd=sdd, # Source-to-detector distance (i.e., the C-arm's focal length) | ||
height=height, # Height of the DRR (if width is not seperately provided, the generated image is square) | ||
width=width, # Width of the DRR | ||
delx=delx, # Pixel spacing (in mm) | ||
dely=dely, # Pixel spacing (in mm) | ||
x0=x0, # # Principal point X-offset | ||
y0=y0, # Principal point Y-offset | ||
p_subsample=p_subsample, # Proportion of pixels to randomly subsample | ||
reshape=reshape, # Return DRR with shape (b, 1, h, w) | ||
reverse_x_axis=reverse_x_axis, # If True, obey radiologic convention (e.g., heart on right) | ||
patch_size=patch_size, # Render patches of the DRR in series | ||
renderer=renderer, # Rendering backend, either "siddon" or "trilinear" | ||
).to(device) | ||
|
||
# Ensure rotations are in radians | ||
if rotations_degrees: | ||
rotations = torch.deg2rad(rotations) | ||
|
||
zero = torch.tensor([[0.0, 0.0, 0.0]], device=device) | ||
pose1 = convert( | ||
zero, translations, parameterization="euler_angles", convention="ZXY" | ||
) | ||
pose2 = convert(rotations, zero, parameterization="euler_angles", convention="ZXY") | ||
pose = pose1.compose(pose2) | ||
|
||
img = drr(pose, mask_to_channels=mask_to_channels) | ||
return img |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,85 @@ | ||
import sys | ||
|
||
import torch | ||
import torch.nn as nn | ||
from monai.networks.nets import UNet | ||
|
||
from config import config | ||
from preprocessing import get_dataloaders, get_datasets, get_transforms | ||
from train import train, validate | ||
|
||
sys.path.append("..") | ||
|
||
|
||
# Define main function | ||
def main(): | ||
|
||
# Get the transforms | ||
transform = get_transforms(contrast_value=1000) | ||
|
||
# Get the datasets | ||
train_dataset, val_dataset = get_datasets( | ||
root_dir="../data", | ||
collection="HCC-TACE-Seg", | ||
transform=transform, | ||
download=True, | ||
download_len=5, | ||
seg_type="SEG", | ||
val_frac=config.val_frac, | ||
seed=config.seed, | ||
) | ||
|
||
# Get the dataloaders | ||
train_loader, val_loader = get_dataloaders( | ||
train_dataset, val_dataset, batch_size=1, num_workers=0 | ||
) | ||
|
||
# Set the device | ||
device = torch.device("cuda" if torch.cuda.is_available() else "cpu") | ||
|
||
# Initialize the model | ||
model = UNet( | ||
in_channels=1, | ||
out_channels=20, | ||
spatial_dims=(512, 512, 96), | ||
channels=5, | ||
strides=1, | ||
).to(device) | ||
|
||
# Initialize the criterion | ||
criterion = nn.CrossEntropyLoss() | ||
|
||
# Initialize the optimizer | ||
optimizer = torch.optim.Adam(model.parameters(), lr=config.learning_rate) | ||
|
||
# Initialize the best validation loss | ||
best_val_loss = float("inf") | ||
|
||
# Train the model | ||
for epoch in range(config.num_epochs): | ||
train_loss, train_acc, train_dice, train_iou = train( | ||
train_loader, model, criterion, optimizer, device | ||
) | ||
val_loss, val_acc, val_dice, val_iou = validate( | ||
val_loader, model, criterion, device | ||
) | ||
|
||
print( | ||
f"Epoch {epoch+1}/{config.num_epochs} - " | ||
f"Train Loss: {train_loss:.4f}, " | ||
f"Train Acc: {train_acc:.4f}, " | ||
f"Train Dice: {train_dice:.4f}, " | ||
f"Train IoU: {train_iou:.4f}, " | ||
f"Val Loss: {val_loss:.4f}, " | ||
f"Val Acc: {val_acc:.4f}, " | ||
f"Val Dice: {val_dice:.4f}, " | ||
f"Val IoU: {val_iou:.4f}" | ||
) | ||
|
||
if val_loss < best_val_loss: | ||
best_val_loss = val_loss | ||
torch.save(model.state_dict(), "best_model.pth") | ||
|
||
|
||
if __name__ == "__main__": | ||
main() |
Empty file.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,93 @@ | ||
import sys | ||
|
||
import numpy as np | ||
import torch | ||
import torch.nn as nn | ||
import tqdm | ||
from monai.networks.nets import UNet | ||
from torchmetrics.functional.classification import ( | ||
accuracy, | ||
dice, | ||
multiclass_jaccard_index, | ||
) | ||
|
||
from preprocessing import get_dataloaders, get_datasets, get_transforms | ||
|
||
sys.path.append("..") | ||
|
||
from config import config | ||
|
||
|
||
# Define the training loop | ||
def train(train_dataloader, model, criterion, optimizer, device): | ||
model.train() | ||
train_loss = 0.0 | ||
train_acc = 0.0 | ||
train_dice = 0.0 | ||
train_iou = 0.0 | ||
|
||
for inputs, targets in tqdm(train_dataloader): | ||
targets = targets.long().squeeze(dim=1) | ||
inputs, targets = inputs.to(device), targets.to(device) | ||
|
||
print(inputs.shape) | ||
|
||
optimizer.zero_grad() | ||
outputs = model(inputs) | ||
loss = criterion(outputs, targets) | ||
loss.backward() | ||
optimizer.step() | ||
|
||
train_loss += loss.detach() | ||
outputs_max = torch.argmax(outputs, dim=1) | ||
|
||
train_acc += accuracy( | ||
outputs_max, targets, task="multiclass", num_classes=20, ignore_index=19 | ||
).detach() | ||
train_dice += dice(outputs, targets, ignore_index=19).detach() | ||
train_iou += multiclass_jaccard_index( | ||
outputs, targets, num_classes=20, ignore_index=19 | ||
).detach() | ||
|
||
num_batches = len(train_dataloader) | ||
return ( | ||
(train_loss / num_batches).item(), | ||
(train_acc / num_batches).item(), | ||
(train_dice / num_batches).item(), | ||
(train_iou / num_batches).item(), | ||
) | ||
|
||
|
||
# Define the validation loop | ||
def validate(val_loader, model, criterion, device): | ||
model.eval() | ||
val_loss = 0.0 | ||
val_acc = 0.0 | ||
val_dice = 0.0 | ||
val_iou = 0.0 | ||
|
||
for inputs, targets in tqdm(val_loader): | ||
targets = targets.long().squeeze(dim=1) | ||
inputs, targets = inputs.to(device), targets.to(device) | ||
|
||
outputs = model(inputs) | ||
loss = criterion(outputs, targets) | ||
|
||
val_loss += loss.detach() | ||
outputs_max = torch.argmax(outputs, dim=1) | ||
|
||
val_acc += accuracy( | ||
outputs_max, targets, task="multiclass", num_classes=20, ignore_index=19 | ||
).detach() | ||
val_dice += dice(outputs, targets, ignore_index=19).detach() | ||
val_iou += multiclass_jaccard_index( | ||
outputs, targets, num_classes=20, ignore_index=19 | ||
).detach() | ||
|
||
num_batches = len(val_loader) | ||
return ( | ||
(val_loss / num_batches).item(), | ||
(val_acc / num_batches).item(), | ||
(val_dice / num_batches).item(), | ||
(val_iou / num_batches).item(), | ||
) |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,25 @@ | ||
import unittest | ||
|
||
from monai.transforms import Compose | ||
|
||
from src.preprocessing import get_transforms | ||
|
||
|
||
class TestGetTransforms(unittest.TestCase): | ||
def setUp(self): | ||
self.transform = get_transforms() | ||
|
||
def test_return_type(self): | ||
self.assertIsInstance(self.transform, Compose) | ||
|
||
def test_transform_length(self): | ||
self.assertTrue(len(self.transform.transforms) > 0) | ||
|
||
def test_transform_keys(self): | ||
expected_keys = ["image", "seg"] | ||
for transform in self.transform.transforms: | ||
self.assertEqual(transform.keys, expected_keys) | ||
|
||
|
||
if __name__ == "__main__": | ||
unittest.main() |