Skip to content

Commit

Permalink
revamp README example; add documentation to main function; make segme…
Browse files Browse the repository at this point in the history
…nt manual default
  • Loading branch information
afrendeiro committed Jan 16, 2024
1 parent e9248ef commit 5f9956b
Show file tree
Hide file tree
Showing 2 changed files with 150 additions and 37 deletions.
56 changes: 29 additions & 27 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -26,58 +26,60 @@ Note that the package uses setuptools-scm for version control and therefore the i

## Usage

```python
import requests
import pandas as pd
import torch
import tqdm
This package is meant both for interactive use and for use in a pipeline at scale.
By default, actions do not return anything; instead, they save their results to disk in files located relative to the slide file.

All major functions have sensible defaults but allow for customization.
Please check the docstring of each function for more information.

```python
from wsi_core import WholeSlideImage
from wsi_core.utils import Path

# Get example slide image
slide_name = "GTEX-1117F-1126"
slide_file = Path(f"{slide_name}.svs")
slide_file = Path("GTEX-12ZZW-2726.svs")
if not slide_file.exists():
url = f"https://brd.nci.nih.gov/brd/imagedownload/{slide_name}"
import requests
url = f"https://brd.nci.nih.gov/brd/imagedownload/{slide_file.stem}"
with open(slide_file, "wb") as handle:
req = requests.get(url)
handle.write(req.content)

# Instantiate slide class
# Instantiate slide object
slide = WholeSlideImage(slide_file)

# Segment tissue
url = "https://raw.githubusercontent.com/mahmoodlab/CLAM/master/presets/bwh_biopsy.csv"
params = pd.read_csv(url).squeeze()
slide.segmentTissue(seg_level=2, filter_params=params.to_dict())
slide.saveSegmentation()
# # alternatively, simply:
# Instantiate slide object
slide = WholeSlideImage(slide_file, attributes=dict(donor="GTEX-12ZZW"))

# Segment tissue (segmentation mask is stored as polygons in slide.contours_tissue)
slide.segment()

# Visualize segmentation
slide.initSegmentation()
slide.visWSI(vis_level=2).save(f"{slide_name}.segmentation.png")
# Visualize segmentation (PNG file is saved in same directory as slide_file)
slide.plot_segmentation()

# Generate coordinates for tiling in h5 file (highest resolution, non-overlapping tiles)
# # Only store coordinates in hdf5 file:
slide.process_contours('.', patch_level=0, patch_size=224, step_size=224)
# # alternatively, simply:
slide.tile()
# # Store coordinates and images in hdf5 file:
slide.createPatches_bag_hdf5(patch_level=0, patch_size=224, step_size=224)

# Get coordinates
# Get coordinates (from h5 file)
slide.get_tile_coordinates()
# Get images
slide.get_tile_images()
# Get single tile using lower level OpenSlide handle

# Get image of single tile using lower level OpenSlide handle (`wsi` object)
slide.wsi.read_region((1_000, 2_000), level=0, size=(224, 224))

# Get tile images for all tiles (as a generator)
images = slide.get_tile_images()
for img in images:
...

# Save tile images to disk as individual jpg files
slide.save_tile_images(output_dir=slide_file.parent / (slide_file.stem + "_tiles"))

# Use in a torch dataloader
loader = slide.as_data_loader()

# Extract features
import torch
from tqdm import tqdm
model = torch.hub.load("pytorch/vision", "resnet50", pretrained=True)
for count, (batch, coords) in tqdm(enumerate(loader), total=len(loader)):
with torch.no_grad():
Expand Down
131 changes: 121 additions & 10 deletions wsi_core/WholeSlideImage.py
Original file line number Diff line number Diff line change
Expand Up @@ -39,14 +39,24 @@ class WholeSlideImage(object):
def __init__(
self,
path: Path | _Path | str,
*,
attributes: tp.Optional[dict[str, tp.Any]] = None,
mask_file: Path | None = None,
hdf5_file: Path | None = None,
):
"""
Args:
path (str): fullpath to WSI file
attributes
WholeSlideImage object for handling WSI.
Parameters
----------
path: Path
Path to WSI file.
attributes: dict[str, tp.Any]
Optional dictionary with attributes to store in the object.
mask_file: Path
Path to file used to save segmentation. Default is `path.with_suffix(".segmentation.pickle")`.
hdf5_file: Path
Path to file used to save tile coordinates (and images). Default is `path.with_suffix(".h5")`.
"""
if not isinstance(path, Path):
path = Path(path)
Expand Down Expand Up @@ -760,8 +770,35 @@ def segment_tissue_manual(self, level: int | None = None, color_space: str = "RG
def segment(
self,
params: tp.Optional[dict[str, tp.Any]] = None,
method: str = "CLAM",
method: str = "manual",
) -> None:
"""
Segment the WSI for tissue and background.
Segmentations are saved as a list of contours and holes in the
`contours_tissue` and `holes_tissue` attributes.
This object is then saved to disk as a pickle file, by default
in the same directory as the WSI with the same name but with a
`.segmentation.pickle` suffix.
A visualization of the segmentation will also be plotted by
calling `plot_segmentation` and saved as a PNG file (by default
in the same directory as the WSI with the same name but with a
`.segmentation.png` suffix).
Parameters
----------
params: dict[str, tp.Any]
Parameters for the segmentation method.
method: str
Segmentation method to use. Either "manual" or "CLAM".
The CLAM method uses the parameters given in `params` or
the default parameters (bwh_biopsy) if `params` is None.
Returns
-------
None
"""
assert method in ["manual", "CLAM"], f"Unknown segmentation method: {method}"
if method == "manual":
self.segment_tissue_manual(**(params or {}))
Expand Down Expand Up @@ -844,6 +881,22 @@ def segment(
# return fig

def plot_segmentation(self, output_file: tp.Optional[Path] = None) -> None:
"""
Plot the segmentation of the WSI.
This plot is an overlay of a low resolution image of the WSI and the
contours of the tissue and holes.
Parameters
----------
output_file: Path
Path to save the plot to. If None, save to
`self.path.with_suffix(".segmentation.png")`.
Returns
-------
None
"""
if output_file is None:
output_file = self.path.with_suffix(".segmentation.png")

Expand All @@ -863,11 +916,12 @@ def tile(
Parameters
----------
patch_level: int
Level to extract patches from.
WSI level to extract patches from. Default is 0, which is by convention
the highest resolution, although this is not always true.
patch_size: int
Size of patches to extract.
Size of patches to extract in pixels.
step_size: int
Step size between patches.
Step size between patches in pixels.
contour_subset: list[int]
1-based index of which contours to use. If None, use all contours.
Expand Down Expand Up @@ -900,7 +954,22 @@ def has_tile_images(self):
with h5py.File(self.hdf5_file, "r") as h5:
return "imgs" in h5

def get_tile_coordinates(self, hdf5_file: Path | None = None):
def get_tile_coordinates(self, hdf5_file: Path | None = None) -> np.ndarray:
"""
Retrieve coordinates of tiles from HDF5 file.
By default uses the `self.hdf5_file` attribute, but can be overridden.
Parameters
----------
hdf5_file: Path
Path to HDF5 file containing tile coordinates.
Returns
-------
np.ndarray
Array of tile coordinates with shape (N, 2).
"""
if hdf5_file is None:
hdf5_file = self.hdf5_file # or self.tile_h5
with h5py.File(hdf5_file, "r") as h5:
Expand All @@ -919,7 +988,23 @@ def get_tile_images(
self,
hdf5_file: Path | None = None,
as_generator: bool = True,
):
) -> tp.Generator[np.ndarray, None, None] | np.ndarray:
"""
Get tile images from HDF5 file.
By default it returns a generator, but this can be overridden to return all tiles as an array with a batch dimension.
By default uses the `self.hdf5_file` attribute, but can be overridden.
Parameters
----------
hdf5_file: Path
Path to HDF5 file containing tile images.
Returns
-------
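tp.Generator[np.ndarray, None, None] | np.ndarray
Generator of tile images if `as_generator` is True (the default), otherwise an array of tile images with shape (N, 3, H, W).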
"""
if hdf5_file is None:
hdf5_file = self.hdf5_file # or self.tile_h5

Expand Down Expand Up @@ -959,16 +1044,43 @@ def save_tile_images(
n: int | None = None,
frac: float = 1.0,
):
"""
Save tile images as individual files to disk.
Parameters
----------
output_dir: Path
Directory to save tile images to.
format: str
File format to save images as.
attributes: bool
Whether to include attributes in filename.
n: int
Number of tiles to save. Default is to save all.
frac: float
Fraction of tiles to save. Default is to save all.
Returns
-------
None
"""
import pandas as pd

if n is not None:
assert frac is None, "Only one of `n` or `frac` can be used."
if frac is not None:
assert n is None, "Only one of `n` or `frac` can be used."

output_dir.mkdir(exist_ok=True, parents=True)

_attributes = {}
if attributes:
_attributes = self.attributes if self.attributes is not None else {}
output_prefix = output_dir / (
self.name + ("." + ".".join(_attributes.values()))
)
else:
output_prefix = output_dir / self.name

hdf5_file = self.hdf5_file # or self.tile_h5
level, size = self.get_tile_coordinate_level_size(hdf5_file)
Expand All @@ -981,7 +1093,6 @@ def save_tile_images(

sel = pd.Series(range(nc)).sample(frac=frac, n=n).values

output_prefix = output_dir / (self.name + ("." + ".".join(_attributes.values())))
for coord in coords[sel]:
# Output in the form of: slide_name.attr[0].attr[1].attr[n].x.y.format
fp = output_prefix + f".{coord[0]}.{coord[1]}.{format}"
Expand Down

0 comments on commit 5f9956b

Please sign in to comment.