Skip to content

Commit

Permalink
WSI reader (#1548)
Browse files Browse the repository at this point in the history
* Implement CuImageReader and OpenSlideReader

Signed-off-by: Behrooz <[email protected]>

* Add unittests for CuImageReader

Signed-off-by: Behrooz <[email protected]>

* Add unittests for OpenSlideReader

Signed-off-by: Behrooz <[email protected]>

* Sort imports

Signed-off-by: Behrooz <[email protected]>

* Add correct boundaries

Signed-off-by: Behrooz <[email protected]>

* Add test cases for reading patches on a grid for CuImage

Signed-off-by: Behrooz <[email protected]>

* Add patch whole slide imaging dataset for pathology

Signed-off-by: Behrooz <[email protected]>

* Add test case for read patches for OpenSlide

Signed-off-by: Behrooz <[email protected]>

* flake8 and few minor changes

Signed-off-by: Behrooz <[email protected]>

* black

Signed-off-by: Behrooz <[email protected]>

* flake8

Signed-off-by: Behrooz <[email protected]>

* Add kwargs to CuImageReader and OpenSlideReader's read method

Signed-off-by: Behrooz <[email protected]>

* Change the type hint from np.dtype to DTypeLike

Signed-off-by: Behrooz <[email protected]>

* Fix a bug

Signed-off-by: Behrooz <[email protected]>

* Implement WSIReader and unittests

Signed-off-by: Behrooz <[email protected]>

* Minor updates

Signed-off-by: Behrooz <[email protected]>

* Fix few typing issues

Signed-off-by: Behrooz <[email protected]>

* Revert datasets

Signed-off-by: Behrooz <[email protected]>

* Add shape property to openslide image object
Reverse size to be compatible with output size (hxw)

Signed-off-by: Behrooz <[email protected]>

* Add unittest for loading the whole image
Reverse the size according to the WSIReader

Signed-off-by: Behrooz <[email protected]>

* Update the whole image size

Signed-off-by: Behrooz <[email protected]>

* Remove optional size

Signed-off-by: Behrooz <[email protected]>

* Remove optional dtype

Signed-off-by: Behrooz <[email protected]>

* Remove _get_spatial_shape return type

Signed-off-by: Behrooz <[email protected]>

* Reverse the orders of dimensions of `location`
to be compatible with image shape

Signed-off-by: Behrooz <[email protected]>

* Change test cases to use smaller image and reverse location's dimensions

Signed-off-by: Behrooz <[email protected]>

* Replace the test TIFF and some upgrades

Signed-off-by: Behrooz <[email protected]>

* Update dependencies for OpenSlide

Signed-off-by: Behrooz <[email protected]>

* Update unittests for OpenSlide and CuImage

Signed-off-by: Behrooz <[email protected]>

* Fix openslide dependency

Signed-off-by: Behrooz <[email protected]>

* Fix doc dependencies

Signed-off-by: Behrooz <[email protected]>

* Minor changes

Signed-off-by: Behrooz <[email protected]>

* Few variable name changes

Signed-off-by: Behrooz <[email protected]>

* Add EnsureChannelFirst

Signed-off-by: Behrooz <[email protected]>

* Add metadata to WSIReader

Signed-off-by: Behrooz <[email protected]>
  • Loading branch information
bhashemian authored Mar 5, 2021
1 parent 75b5772 commit 889c9f9
Show file tree
Hide file tree
Showing 8 changed files with 373 additions and 5 deletions.
1 change: 1 addition & 0 deletions docs/requirements.txt
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@ pytorch-ignite==0.4.2
numpy>=1.17
itk>=5.0
nibabel
openslide-python==1.1.2
parameterized
scikit-image>=0.14.2
tensorboard
Expand Down
4 changes: 4 additions & 0 deletions docs/source/data.rst
Original file line number Diff line number Diff line change
Expand Up @@ -105,6 +105,10 @@ PILReader
.. autoclass:: PILReader
:members:

WSIReader
~~~~~~~~~
.. autoclass:: WSIReader
:members:

Nifti format handling
---------------------
Expand Down
2 changes: 1 addition & 1 deletion monai/data/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,7 +24,7 @@
from .decathlon_datalist import load_decathlon_datalist, load_decathlon_properties
from .grid_dataset import GridPatchDataset, PatchDataset
from .image_dataset import ImageDataset
from .image_reader import ImageReader, ITKReader, NibabelReader, NumpyReader, PILReader
from .image_reader import ImageReader, ITKReader, NibabelReader, NumpyReader, PILReader, WSIReader
from .iterable_dataset import IterableDataset
from .nifti_saver import NiftiSaver
from .nifti_writer import write_nifti
Expand Down
161 changes: 157 additions & 4 deletions monai/data/image_reader.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,26 +19,31 @@

from monai.config import DtypeLike, KeysCollection
from monai.data.utils import correct_nifti_header_if_necessary
from monai.transforms.utility.array import EnsureChannelFirst
from monai.utils import ensure_tuple, optional_import

from .utils import is_supported_format

if TYPE_CHECKING:
import cuimage
import itk # type: ignore
import nibabel as nib
import openslide
from itk import Image # type: ignore
from nibabel.nifti1 import Nifti1Image
from PIL import Image as PILImage

has_itk = has_nib = has_pil = True
has_itk = has_nib = has_pil = has_cux = has_osl = True
else:
itk, has_itk = optional_import("itk", allow_namespace_pkg=True)
Image, _ = optional_import("itk", allow_namespace_pkg=True, name="Image")
nib, has_nib = optional_import("nibabel")
Nifti1Image, _ = optional_import("nibabel.nifti1", name="Nifti1Image")
PILImage, has_pil = optional_import("PIL.Image")
cuimage, has_cux = optional_import("cuimage")
openslide, has_osl = optional_import("openslide")

__all__ = ["ImageReader", "ITKReader", "NibabelReader", "NumpyReader", "PILReader"]
__all__ = ["ImageReader", "ITKReader", "NibabelReader", "NumpyReader", "PILReader", "WSIReader"]


class ImageReader(ABC):
Expand Down Expand Up @@ -264,10 +269,10 @@ def _get_affine(self, img) -> np.ndarray:
origin = np.asarray(img.GetOrigin())

direction = np.asarray(direction)
affine = np.eye(direction.shape[0] + 1)
affine: np.ndarray = np.eye(direction.shape[0] + 1)
affine[(slice(-1), slice(-1))] = direction @ np.diag(spacing)
affine[(slice(-1), -1)] = origin
return np.asarray(affine)
return affine

def _get_spatial_shape(self, img) -> np.ndarray:
"""
Expand Down Expand Up @@ -626,3 +631,151 @@ def _get_spatial_shape(self, img) -> np.ndarray:
"""
# the img data should have no channel dim or the last dim is channel
return np.asarray((img.width, img.height))


class WSIReader(ImageReader):
    """
    Read whole slide images and extract regions, or grids of patches, from them.

    Args:
        reader_lib: backend library used to open the images, either "cuClaraImage" or "OpenSlide"
            (compared case-insensitively). Defaults to "cuClaraImage".

    Raises:
        ValueError: When `reader_lib` is neither "cuClaraImage" nor "OpenSlide".

    """

    def __init__(self, reader_lib: str = "cuClaraImage"):
        super().__init__()
        self.reader_lib = reader_lib.lower()
        if self.reader_lib == "openslide":
            self.wsi_reader = openslide.OpenSlide
            print("> OpenSlide is being used.")
        elif self.reader_lib == "cuclaraimage":
            self.wsi_reader = cuimage.CuImage
            print("> CuImage is being used.")
        else:
            raise ValueError('`reader_lib` should be either "cuClaraImage" or "OpenSlide"')

    def verify_suffix(self, filename: Union[Sequence[str], str]) -> bool:
        """
        Verify whether the specified file or files format is supported by WSI reader.

        Args:
            filename: file name or a list of file names to read.
                if a list of files, verify all the suffixes.

        """
        return is_supported_format(filename, ["tif", "tiff"])

    def read(self, data: Union[Sequence[str], str, np.ndarray], **kwargs):
        """
        Open the specified file or files with the selected backend.

        Note that nothing is decoded here: the returned object is the backend image object
        (CuImage or OpenSlide), or a list of them when more than one file name is given.

        Args:
            data: file name or a list of file names to read.
            kwargs: accepted for interface compatibility; currently unused.

        """
        img_: List = []

        filenames: Sequence[str] = ensure_tuple(data)
        for name in filenames:
            img = self.wsi_reader(name)
            if self.reader_lib == "openslide":
                # `dimensions` is indexed as (width, height) here; attach a (height, width, 3)
                # `shape` so that `get_data` can read `img.shape` for either backend.
                img.shape = (img.dimensions[1], img.dimensions[0], 3)
            img_.append(img)

        return img_ if len(filenames) > 1 else img_[0]

    def get_data(
        self,
        img,
        location: Tuple[int, int] = (0, 0),
        size: Optional[Tuple[int, int]] = None,
        level: int = 0,
        dtype: DtypeLike = np.uint8,
        grid_shape: Tuple[int, int] = (1, 1),
        patch_size: Optional[int] = None,
    ):
        """
        Extract a region from a WSI image as a numpy array, optionally split into patches.

        Args:
            img: an image object returned by `read` for a single file.
            location: (y_min, x_min) tuple giving the top left pixel of the region, i.e. in the
                same (row, column) order as the image shape (default=(0, 0)). The order is
                reversed internally before it is handed to the backend.
            size: (height, width) tuple giving the region size at the given level (`level`).
                If omitted, `location` must be (0, 0) and the whole image is read.
            level: the level number (default=0).
            dtype: the data type of output image.
            grid_shape: (rows, columns) tuple defining the grid on which patches are extracted.
                Only used when `patch_size` is given.
            patch_size: side length of the square patches extracted at the given level.
                If omitted, the region itself is returned and `grid_shape` is ignored.

        Returns:
            A tuple `(patches, metadata)`. `patches` is the region after `EnsureChannelFirst`
            when `patch_size` is None, otherwise an array of shape
            (n_patches, channels, patch_size, patch_size). `metadata` holds "spatial_shape"
            (the region `size`) and "original_channel_dim" (-1).

        Raises:
            ValueError: When `size` is omitted and `location` is not (0, 0).

        """
        if size is None:
            # `tuple(...)` lets a list or array `location` compare equal to the default.
            if tuple(location) == (0, 0):
                # the maximum size is set to HxW; this assumes that every level halves the
                # resolution of the previous one (downsample factor of 2 ** level)
                size = (img.shape[0] // (2 ** level), img.shape[1] // (2 ** level))
                print(f"Reading the whole image at level={level} with shape={size}")
            else:
                raise ValueError("Size need to be provided to extract the region!")

        region = self._extract_region(img, location=location, size=size, level=level, dtype=dtype)

        metadata: Dict = {}
        metadata["spatial_shape"] = size
        metadata["original_channel_dim"] = -1
        region = EnsureChannelFirst()(region, metadata)

        if patch_size is None:
            patches = region
        else:
            patches = self._extract_patches(
                region, patch_size=(patch_size, patch_size), grid_shape=grid_shape, dtype=dtype
            )

        return patches, metadata

    def _extract_region(
        self,
        img_obj,
        size: Tuple[int, int],
        location: Tuple[int, int] = (0, 0),
        level: int = 0,
        dtype: DtypeLike = np.uint8,
    ):
        """
        Read one region from the backend image object and return it as a numpy array.

        `size` and `location` are given in (row, column) order and are reversed here
        before being passed to the backend's `read_region`.
        """
        # reverse the order of dimensions for size and location to be compatible with image shape
        size = size[::-1]
        location = location[::-1]
        region = img_obj.read_region(location=location, size=size, level=level)
        if self.reader_lib == "openslide":
            # the OpenSlide result is converted to RGB before it becomes an array
            region = region.convert("RGB")
        # convert to numpy
        region = np.asarray(region, dtype=dtype)

        return region

    def _extract_patches(
        self,
        region: np.ndarray,
        grid_shape: Tuple[int, int] = (1, 1),
        patch_size: Optional[Tuple[int, int]] = None,
        dtype: DtypeLike = np.uint8,
    ):
        """
        Split a channel-first region into patches centered on the cells of a grid.

        Args:
            region: array indexed as (channels, rows, columns).
            grid_shape: (rows, columns) of the grid.
            patch_size: (height, width) of each patch; defaults to the size of one grid cell.
            dtype: the data type of the returned array.

        Returns:
            `region` unchanged when no `patch_size` is given and the grid is (1, 1); otherwise
            an array of shape (n_patches, channels, patch_height, patch_width). Patches are
            ordered column by column, running down the rows within each column.

        """
        if patch_size is None and tuple(grid_shape) == (1, 1):
            return region

        n_patches = grid_shape[0] * grid_shape[1]
        # take the channel count from the data instead of assuming RGB
        n_channels = region.shape[0]
        region_size = region.shape[1:]

        if patch_size is None:
            patch_size = (region_size[0] // grid_shape[0], region_size[1] // grid_shape[1])

        # split the region into patches on the grid and center crop them to patch size
        flat_patch_grid = np.zeros((n_patches, n_channels, patch_size[0], patch_size[1]), dtype=dtype)
        # start_points[0] holds the row offsets and start_points[1] the column offsets:
        # the center of each grid cell minus half a patch
        start_points = [
            np.round(region_size[i] * (0.5 + np.arange(grid_shape[i])) / grid_shape[i] - patch_size[i] / 2).astype(int)
            for i in range(2)
        ]
        # columns vary slowest, which keeps the established output order
        corners = ((row_start, col_start) for col_start in start_points[1] for row_start in start_points[0])
        for idx, (row_start, col_start) in enumerate(corners):
            row_end = row_start + patch_size[0]
            col_end = col_start + patch_size[1]
            flat_patch_grid[idx] = region[:, row_start:row_end, col_start:col_end]

        return flat_patch_grid
1 change: 1 addition & 0 deletions requirements-dev.txt
Original file line number Diff line number Diff line change
Expand Up @@ -30,3 +30,4 @@ Sphinx==3.3.0
recommonmark==0.6.0
sphinx-autodoc-typehints==1.11.1
sphinx-rtd-theme==0.5.0
openslide-python==1.1.2
3 changes: 3 additions & 0 deletions setup.cfg
Original file line number Diff line number Diff line change
Expand Up @@ -32,6 +32,7 @@ all =
torchvision
itk>=5.0
tqdm>=4.47.0
openslide-python==1.1.2
nibabel =
nibabel
skimage =
Expand All @@ -54,6 +55,8 @@ lmdb =
lmdb
psutil =
psutil
openslide =
openslide-python==1.1.2

[flake8]
select = B,C,E,F,N,P,T4,W,B9
Expand Down
103 changes: 103 additions & 0 deletions tests/test_cuimage_reader.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,103 @@
import os
import unittest
from unittest import skipUnless
from urllib import request

import numpy as np
from numpy.testing import assert_array_equal
from parameterized import parameterized

from monai.data.image_reader import WSIReader
from monai.utils import optional_import

_, has_cui = optional_import("cuimage")


# Remote TIFF whole slide image that the tests download on first use.
FILE_URL = "http://openslide.cs.cmu.edu/download/openslide-testdata/Generic-TIFF/CMU-1.tiff"
# Height and width, in pixels, that the whole-image test expects when reading at the default level.
HEIGHT = 32914
WIDTH = 46000

# Whole image read: [file url, expected channel-first shape of the returned array].
TEST_CASE_0 = [FILE_URL, (3, HEIGHT, WIDTH)]

# Region read at level 0: a 2x1 (height x width) region whose top left corner is the image center.
# Layout: [file url, keyword arguments for `get_data`, expected channel-first array].
TEST_CASE_1 = [
FILE_URL,
{"location": (HEIGHT // 2, WIDTH // 2), "size": (2, 1), "level": 0},
np.array([[[246], [246]], [[246], [246]], [[246], [246]]]),
]

# Region read at level 2: a 2x1 region starting at the top left corner of the image.
TEST_CASE_2 = [
FILE_URL,
{"location": (0, 0), "size": (2, 1), "level": 2},
np.array([[[239], [239]], [[239], [239]], [[239], [239]]]),
]

# Patch extraction: an 8x8 region at level 2 split on a 2x1 grid into two 2x2 patches.
# The expected array is indexed as (patch, channel, row, column).
TEST_CASE_3 = [
FILE_URL,
{
"location": (0, 0),
"size": (8, 8),
"level": 2,
"grid_shape": (2, 1),
"patch_size": 2,
},
np.array(
[
[[[239, 239], [239, 239]], [[239, 239], [239, 239]], [[239, 239], [239, 239]]],
[[[242, 242], [242, 243]], [[242, 242], [242, 243]], [[242, 242], [242, 243]]],
]
),
]

# Patch extraction: same region and grid as TEST_CASE_3, but with 1x1 patches.
TEST_CASE_4 = [
FILE_URL,
{
"location": (0, 0),
"size": (8, 8),
"level": 2,
"grid_shape": (2, 1),
"patch_size": 1,
},
np.array([[[[239]], [[239]], [[239]]], [[[243]], [[243]], [[243]]]]),
]


class TestCuClaraImageReader(unittest.TestCase):
    """Check `WSIReader` with the CuClaraImage backend on a downloaded test slide."""

    def _read_with_cuclaraimage(self, file_url, **patch_info):
        """Fetch the slide if needed, open it and return the array part of `get_data`."""
        slide_path = self.camelyon_data_download(file_url)
        wsi_reader = WSIReader("CuClaraImage")
        slide = wsi_reader.read(slide_path)
        array, _ = wsi_reader.get_data(slide, **patch_info)
        return array

    @parameterized.expand([TEST_CASE_0])
    @skipUnless(has_cui, "Requires CuClaraImage")
    def test_read_whole_image(self, file_url, expected_shape):
        # no region arguments: the reader falls back to the whole image
        whole_image = self._read_with_cuclaraimage(file_url)
        self.assertTupleEqual(whole_image.shape, expected_shape)

    @parameterized.expand([TEST_CASE_1, TEST_CASE_2])
    @skipUnless(has_cui, "Requires CuClaraImage")
    def test_read_region(self, file_url, patch_info, expected_img):
        region = self._read_with_cuclaraimage(file_url, **patch_info)
        self.assertTupleEqual(region.shape, expected_img.shape)
        # assert_array_equal raises on mismatch and returns None otherwise
        comparison = assert_array_equal(region, expected_img)
        self.assertIsNone(comparison)

    @parameterized.expand([TEST_CASE_3, TEST_CASE_4])
    @skipUnless(has_cui, "Requires CuClaraImage")
    def test_read_patches(self, file_url, patch_info, expected_img):
        patches = self._read_with_cuclaraimage(file_url, **patch_info)
        self.assertTupleEqual(patches.shape, expected_img.shape)
        # assert_array_equal raises on mismatch and returns None otherwise
        comparison = assert_array_equal(patches, expected_img)
        self.assertIsNone(comparison)

    def camelyon_data_download(self, file_url):
        """Return the local file name for `file_url`, downloading it into the working directory if absent."""
        filename = os.path.basename(file_url)
        if os.path.exists(filename):
            return filename
        print(f"Test image [{filename}] does not exist. Downloading...")
        request.urlretrieve(file_url, filename)
        return filename


if __name__ == "__main__":
unittest.main()
Loading

0 comments on commit 889c9f9

Please sign in to comment.