Add mi2rl CheSS model #122

Open · wants to merge 8 commits into master
4 changes: 4 additions & 0 deletions README.md
@@ -115,6 +115,10 @@ model = xrv.baseline_models.jfhealthcare.DenseNet()
# Official Stanford CheXpert model
model = xrv.baseline_models.chexpert.DenseNet(weights_zip="chexpert_weights.zip")

# CheSS: Chest X-Ray Pre-trained Model via Self-supervised Contrastive Learning
# Outputs a 2048-dimensional feature vector
model = xrv.baseline_models.mi2rl.CheSS()

```

Benchmarks of the models are here: [BENCHMARKS.md](BENCHMARKS.md) and the performance of some of the models can be seen in this paper [arxiv.org/abs/2002.02497](https://arxiv.org/abs/2002.02497).
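For context, a minimal sketch of using the new CheSS model end to end, following the preprocessing pattern used elsewhere in this README (the file name is a placeholder, the image is assumed grayscale, and the wrapper upsamples inputs to 512×512 internally):

```python
import torch
import torchvision
import skimage.io
import torchxrayvision as xrv

img = skimage.io.imread("chest_xray.png")   # placeholder path, assumed grayscale
img = xrv.datasets.normalize(img, 255)      # map 8-bit values to xrv's [-1024, 1024] range
img = img[None, ...]                        # add channel dimension
transform = torchvision.transforms.Compose([
    xrv.datasets.XRayCenterCrop(),
    xrv.datasets.XRayResizer(224),
])
img = transform(img)

model = xrv.baseline_models.mi2rl.CheSS()
with torch.no_grad():
    feats = model(torch.from_numpy(img).float()[None, ...])  # shape (1, 2048)
```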
456 changes: 456 additions & 0 deletions scripts/xray_representations2.ipynb

Large diffs are not rendered by default.

1 change: 1 addition & 0 deletions torchxrayvision/baseline_models/__init__.py
@@ -1,3 +1,4 @@
from . import jfhealthcare
from . import chexpert
from . import chestx_det
from . import mi2rl
85 changes: 85 additions & 0 deletions torchxrayvision/baseline_models/mi2rl/__init__.py
@@ -0,0 +1,85 @@
import sys, os
thisfolder = os.path.dirname(__file__)
sys.path.insert(0, thisfolder)
import json
import pathlib
import torch
import torch.nn as nn
import torchvision
import chess_resnet
import torchxrayvision as xrv


class CheSS(nn.Module):
    """CheSS: Chest X-Ray Pre-trained Model via Self-supervised Contrastive Learning

    Paper: https://link.springer.com/article/10.1007/s10278-023-00782-4
    Source: https://github.com/mi2rl/CheSS
    License: Apache-2.0
    """
    def __init__(self):
        super().__init__()

        # ResNet-50 backbone with a 128-dim projection head; the fc head is
        # bypassed in features(), which returns the 2048-dim pooled features
        self.model = chess_resnet.resnet50(num_classes=128)

        url = "https://github.com/mlmed/torchxrayvision/releases/download/v1/mi2rl-chess-resnet.pth"

        weights_filename = os.path.basename(url)
        weights_storage_folder = os.path.expanduser(os.path.join("~", ".torchxrayvision", "models_data"))
        self.weights_filename_local = os.path.expanduser(os.path.join(weights_storage_folder, weights_filename))

        if not os.path.isfile(self.weights_filename_local):
            print("Downloading weights...")
            print("If this fails you can run `wget {} -O {}`".format(url, self.weights_filename_local))
            pathlib.Path(weights_storage_folder).mkdir(parents=True, exist_ok=True)
            xrv.utils.download(url, self.weights_filename_local)

        try:
            state_dict = torch.load(self.weights_filename_local, map_location="cpu")
            self.model.load_state_dict(state_dict)
        except Exception as e:
            print("Loading failure. Check weights file:", self.weights_filename_local)
            raise e

        self.upsample = nn.Upsample(size=(512, 512), mode='bilinear', align_corners=False)

        self.normalize = torchvision.transforms.Normalize(0.658, 0.221)

    def transform_from_xrv(self, x):
        """Following https://github.com/mi2rl/CheSS/blob/main/downstream/classification/datasets.py"""

        x = self.upsample(x)

        # rescale from torchxrayvision's [-1024, 1024] convention to [0, 255]
        # (min-max), then apply the CheSS normalization
        x -= x.min()
        x /= (x.max() - x.min())
        x *= 255
        x = self.normalize(x)

        return x

    def features(self, x):

        x = self.transform_from_xrv(x)

        x = self.model.conv1(x)
        x = self.model.bn1(x)
        x = self.model.relu(x)
        x = self.model.maxpool(x)

        x = self.model.layer1(x)
        x = self.model.layer2(x)
        x = self.model.layer3(x)
        x = self.model.layer4(x)

        x = self.model.avgpool(x)
        x = torch.flatten(x, 1)
        return x

    def forward(self, x):
        return self.features(x)

    def __repr__(self):
        return "mi2rl-CheSS"
222 changes: 222 additions & 0 deletions torchxrayvision/baseline_models/mi2rl/chess_resnet.py
@@ -0,0 +1,222 @@
import torch
import torch.nn as nn
import torch.nn.functional as F
import time

__all__ = ['ResNet', 'resnet50', 'resnet152']

def conv3x3(in_planes, out_planes, stride=1, groups=1, dilation=1):
    """3x3 convolution with padding"""
    return nn.Conv2d(in_planes, out_planes, kernel_size=3, stride=stride,
                     padding=dilation, groups=groups, bias=False, dilation=dilation)


def conv1x1(in_planes, out_planes, stride=1):
    """1x1 convolution"""
    return nn.Conv2d(in_planes, out_planes, kernel_size=1, stride=stride, bias=False)


class BasicBlock(nn.Module):
    expansion = 1

    def __init__(self, inplanes, planes, stride=1, downsample=None, groups=1,
                 base_width=64, dilation=1, norm_layer=None):
        super(BasicBlock, self).__init__()
        if norm_layer is None:
            norm_layer = nn.BatchNorm2d
        if groups != 1 or base_width != 64:
            raise ValueError('BasicBlock only supports groups=1 and base_width=64')
        if dilation > 1:
            raise NotImplementedError("Dilation > 1 not supported in BasicBlock")
        # Both self.conv1 and self.downsample layers downsample the input when stride != 1
        self.conv1 = conv3x3(inplanes, planes, stride)
        self.bn1 = norm_layer(planes)
        self.relu = nn.ReLU(inplace=True)
        self.conv2 = conv3x3(planes, planes)
        self.bn2 = norm_layer(planes)
        self.downsample = downsample
        self.stride = stride

    def forward(self, x):
        identity = x

        out = self.conv1(x)
        out = self.bn1(out)
        out = self.relu(out)

        out = self.conv2(out)
        out = self.bn2(out)

        if self.downsample is not None:
            identity = self.downsample(x)

        out += identity
        out = self.relu(out)

        return out


class Bottleneck(nn.Module):
    # Bottleneck in torchvision places the stride for downsampling at the 3x3 convolution (self.conv2),
    # while the original implementation places the stride at the first 1x1 convolution (self.conv1)
    # according to "Deep Residual Learning for Image Recognition" https://arxiv.org/abs/1512.03385.
    # This variant is also known as ResNet V1.5 and improves accuracy according to
    # https://ngc.nvidia.com/catalog/model-scripts/nvidia:resnet_50_v1_5_for_pytorch.

    expansion = 4

    def __init__(self, inplanes, planes, stride=1, downsample=None, groups=1,
                 base_width=64, dilation=1, norm_layer=None):
        super(Bottleneck, self).__init__()
        if norm_layer is None:
            norm_layer = nn.BatchNorm2d
        width = int(planes * (base_width / 64.)) * groups
        # Both self.conv2 and self.downsample layers downsample the input when stride != 1
        self.conv1 = conv1x1(inplanes, width)
        self.bn1 = norm_layer(width)
        self.conv2 = conv3x3(width, width, stride, groups, dilation)
        self.bn2 = norm_layer(width)
        self.conv3 = conv1x1(width, planes * self.expansion)
        self.bn3 = norm_layer(planes * self.expansion)
        self.relu = nn.ReLU(inplace=True)
        self.downsample = downsample
        self.stride = stride

    def forward(self, x):
        identity = x

        out = self.conv1(x)
        out = self.bn1(out)
        out = self.relu(out)

        out = self.conv2(out)
        out = self.bn2(out)
        out = self.relu(out)

        out = self.conv3(out)
        out = self.bn3(out)

        if self.downsample is not None:
            identity = self.downsample(x)

        out += identity
        out = self.relu(out)

        return out

class ResNet(nn.Module):

    def __init__(self, block, layers, num_classes=512, zero_init_residual=False,
                 groups=1, width_per_group=64, replace_stride_with_dilation=None,
                 norm_layer=None):
        super(ResNet, self).__init__()
        if norm_layer is None:
            norm_layer = nn.BatchNorm2d
        self._norm_layer = norm_layer

        self.inplanes = 64
        self.dilation = 1
        if replace_stride_with_dilation is None:
            # each element in the tuple indicates if we should replace
            # the 2x2 stride with a dilated convolution instead
            replace_stride_with_dilation = [False, False, False]
        if len(replace_stride_with_dilation) != 3:
            raise ValueError("replace_stride_with_dilation should be None "
                             "or a 3-element tuple, got {}".format(replace_stride_with_dilation))
        self.groups = groups
        self.base_width = width_per_group

        # single-channel (grayscale) stem, unlike torchvision's 3-channel ResNet
        self.conv1 = nn.Conv2d(1, self.inplanes, kernel_size=7, stride=2, padding=3,
                               bias=False)
        self.bn1 = norm_layer(self.inplanes)
        self.relu = nn.ReLU(inplace=True)
        self.maxpool = nn.MaxPool2d(kernel_size=3, stride=2, padding=1)
        self.layer1 = self._make_layer(block, 64, layers[0])
        self.layer2 = self._make_layer(block, 128, layers[1], stride=2,
                                       dilate=replace_stride_with_dilation[0])
        self.layer3 = self._make_layer(block, 256, layers[2], stride=2,
                                       dilate=replace_stride_with_dilation[1])
        self.layer4 = self._make_layer(block, 512, layers[3], stride=2,
                                       dilate=replace_stride_with_dilation[2])
        self.avgpool = nn.AdaptiveAvgPool2d((1, 1))
        self.fc = nn.Linear(512 * block.expansion, num_classes)

        for m in self.modules():
            if isinstance(m, nn.Conv2d):
                nn.init.kaiming_normal_(m.weight, mode='fan_out', nonlinearity='relu')
            elif isinstance(m, (nn.BatchNorm2d, nn.GroupNorm)):
                nn.init.constant_(m.weight, 1)
                nn.init.constant_(m.bias, 0)

    def _make_layer(self, block, planes, blocks, stride=1, dilate=False):
        norm_layer = self._norm_layer
        downsample = None
        previous_dilation = self.dilation
        if dilate:
            self.dilation *= stride
            stride = 1
        if stride != 1 or self.inplanes != planes * block.expansion:
            downsample = nn.Sequential(
                conv1x1(self.inplanes, planes * block.expansion, stride),
                norm_layer(planes * block.expansion),
            )

        layers = []
        layers.append(block(self.inplanes, planes, stride, downsample, self.groups,
                            self.base_width, previous_dilation, norm_layer))
        self.inplanes = planes * block.expansion

        for _ in range(1, blocks):
            layers.append(block(self.inplanes, planes, groups=self.groups,
                                base_width=self.base_width, dilation=self.dilation,
                                norm_layer=norm_layer))

        return nn.Sequential(*layers)

    def _forward_impl(self, x):
        # See note [TorchScript super()]
        x = self.conv1(x)
        x = self.bn1(x)
        x = self.relu(x)
        x = self.maxpool(x)

        x = self.layer1(x)
        x = self.layer2(x)
        x = self.layer3(x)
        x = self.layer4(x)

        x = self.avgpool(x)
        x = torch.flatten(x, 1)
        x = self.fc(x)

        return x

    def forward(self, x):
        return self._forward_impl(x)

def _resnet(arch, block, layers, pretrained, progress, **kwargs):
    # pretrained/progress are accepted for API compatibility but ignored here;
    # the CheSS wrapper loads the released weights explicitly via load_state_dict.
    model = ResNet(block, layers, **kwargs)
    return model

def resnet50(pretrained=False, progress=True, **kwargs):
    r"""ResNet-50 model from
    `"Deep Residual Learning for Image Recognition" <https://arxiv.org/pdf/1512.03385.pdf>`_

    Args:
        pretrained (bool): unused; kept for API compatibility
        progress (bool): unused; kept for API compatibility
    """
    return _resnet('resnet50', Bottleneck, [3, 4, 6, 3], pretrained, progress,
                   **kwargs)


def resnet152(pretrained=False, progress=True, **kwargs):
    r"""ResNet-152 model from
    `"Deep Residual Learning for Image Recognition" <https://arxiv.org/pdf/1512.03385.pdf>`_

    Args:
        pretrained (bool): unused; kept for API compatibility
        progress (bool): unused; kept for API compatibility
    """
    return _resnet('resnet152', Bottleneck, [3, 8, 36, 3], pretrained, progress,
                   **kwargs)
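For reference, a hedged sketch of instantiating this single-channel backbone directly (assuming the module is importable as `chess_resnet`; the shapes are illustrative):

```python
import torch
import chess_resnet

backbone = chess_resnet.resnet50(num_classes=128)  # same head size the CheSS wrapper uses
x = torch.randn(2, 1, 512, 512)                    # batch of single-channel images
out = backbone(x)
print(out.shape)                                   # torch.Size([2, 128]) from the fc projection head
```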
19 changes: 19 additions & 0 deletions torchxrayvision/utils.py
@@ -1,3 +1,5 @@
import sys
import requests

def in_notebook():
    try:
@@ -11,4 +13,21 @@ def in_notebook():
    return True


# from here https://sumit-ghosh.com/articles/python-download-progress-bar/
def download(url: str, filename: str):
    with open(filename, 'wb') as f:
        response = requests.get(url, stream=True)
        total = response.headers.get('content-length')

        if total is None:
            f.write(response.content)
        else:
            downloaded = 0
            total = int(total)
            for data in response.iter_content(chunk_size=max(int(total / 1000), 1024 * 1024)):
                downloaded += len(data)
                f.write(data)
                done = int(50 * downloaded / total)
                sys.stdout.write('\r[{}{}]'.format('█' * done, '.' * (50 - done)))
                sys.stdout.flush()
    sys.stdout.write('\n')
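And a minimal sketch of calling this helper (the destination path is a placeholder; the URL is the one used by the CheSS wrapper above):

```python
from torchxrayvision.utils import download

download(
    "https://github.com/mlmed/torchxrayvision/releases/download/v1/mi2rl-chess-resnet.pth",
    "/tmp/mi2rl-chess-resnet.pth",  # placeholder destination
)
```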