Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

add ROIDataset class taking a geopandas dataframe as input #48

Open
wants to merge 1 commit into
base: dev
Choose a base branch
from
Open
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
129 changes: 129 additions & 0 deletions malpolon/data/datasets/torchgeo_datasets.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,8 @@

import numpy as np
import pandas as pd
import rasterio
import geopandas as gpd
import pyproj
from pyproj import CRS, Transformer
from torchgeo.datasets import BoundingBox, RasterDataset
Expand All @@ -26,6 +28,17 @@
ALL_NORTHERN_EPSG_CODES = list(range(32601, 32662))
EUROPE_EPSG_CODE = [3035]

def intersects_with_img(roi, file_list):
    """Tell whether ``roi`` lies strictly inside the extent of any raster file.

    Despite its name, this is a containment test: the ROI must be strictly
    inside a single raster's extent on all four sides. An ROI that merely
    overlaps a raster edge, or that spans several rasters, is rejected.

    The extent is derived from the affine transform's scale and offset terms
    only, so it is exact for axis-aligned rasters (rotation/shear terms are
    ignored).

    Args:
        roi: object exposing ``minx``, ``miny``, ``maxx`` and ``maxy``
            attributes, expressed in the same CRS as the raster files.
        file_list: iterable of paths to candidate raster files.

    Returns:
        True as soon as one raster fully contains ``roi``, False otherwise
        (including when ``file_list`` is empty).
    """
    for file in file_list:
        try:
            ds = rasterio.open(file)
        except rasterio.errors.RasterioIOError:
            # The file list may come from a catch-all glob; anything rasterio
            # cannot open (directories, sidecar files, ...) is not a candidate.
            continue
        with ds:
            tf = ds.transform
            x_start, x_end = tf[2], ds.width * tf[0] + tf[2]
            y_start, y_end = tf[5], ds.height * tf[4] + tf[5]

        # Normalise edge order so rasters with a positive row step (south-up)
        # or a negative column step are handled like north-up ones.
        left, right = min(x_start, x_end), max(x_start, x_end)
        bottom, top = min(y_start, y_end), max(y_start, y_end)

        if (
            roi.minx > left
            and roi.miny > bottom
            and roi.maxx < right
            and roi.maxy < top
        ):
            return True
    return False


class RasterTorchGeoDataset(RasterDataset):
"""Generic torchgeo based raster datasets.
Expand Down Expand Up @@ -473,3 +486,119 @@ def __getitem__(
if self.transforms_data is not None:
sample = self.transforms_data(sample)
return sample


class ROIDataset(RasterDataset):
    """Raster dataset indexed by regions of interest taken from a GeoDataFrame.

    Each row of the GeoDataFrame provides a geometry and a target value. The
    geometry is turned into a bounding box which is used to query the
    underlying raster dataset; the target value is returned next to the raster
    sample under the ``"gt"`` key.
    """

    filename_glob = "*"
    filename_regex = ".*"
    date_format = "%Y%m%d"
    is_image = True
    separate_files = False
    all_bands: List[str] = []
    rgb_bands: List[str] = []
    cmap: Dict[int, Tuple[int, int, int, int]] = {}

    def __init__(
        self,
        root: str,
        gdf: gpd.GeoDataFrame,
        target_var: str,
        target_indexes: Optional[List] = None,
        size: Union[Tuple[float, float], float] = 1,
        units: Units = Units.PIXELS,
        normalize_target_var: bool = False,
        crs: Optional[CRS] = None,
        res: Optional[float] = None,
        bands: Optional[Sequence[str]] = None,
        transforms: Optional[Callable[[Dict[str, Any]], Dict[str, Any]]] = None,
        cache: bool = True,
    ):
        """Build the raster index and prepare the regions of interest.

        Args:
            root: root directory of the raster files, forwarded to
                ``RasterDataset``.
            gdf: GeoDataFrame holding the geometries and the target variable.
            target_var: name of the column of ``gdf`` holding the ground truth.
            target_indexes: optional positional indexes of the rows to keep
                once the GeoDataFrame has been filtered.
            size: size of the square built around point geometries; a scalar
                is expanded to a pair by ``_to_tuple``.
            units: unit of ``size``; pixel sizes are multiplied by the dataset
                resolution.
            normalize_target_var: if True, standardise the target variable
                (subtract the mean, divide by the standard deviation).
            crs: forwarded to ``RasterDataset``.
            res: forwarded to ``RasterDataset``.
            bands: forwarded to ``RasterDataset``.
            transforms: forwarded to ``RasterDataset``.
            cache: forwarded to ``RasterDataset``.
        """
        # Positional call kept as in the original so the parent's own
        # parameter names do not matter.
        super().__init__(
            root,
            crs,
            res,
            bands,
            transforms,
            cache,
        )

        self.target_var = target_var
        self.units = units
        self.size = _to_tuple(size)
        if units == Units.PIXELS:
            # Pixel counts are scaled by the resolution to get a size in CRS
            # units. NOTE(review): the original comment calls these meters,
            # which only holds for a metric CRS - confirm.
            self.size = (self.size[0] * self.res, self.size[1] * self.res)
        self.normalize = normalize_target_var
        # Must be set before _prepare_gdf, which reads it.
        self.target_indexes = target_indexes
        self.gdf = self._prepare_gdf(gdf)
        self.rois = self.gdf['bboxes']

    def _polygon_to_bbox(self, polygon) -> BoundingBox:
        """Convert a geometry into a ``BoundingBox`` covering its extent.

        Args:
            polygon: geometry exposing ``bounds`` as
                ``(minx, miny, maxx, maxy)``.

        Returns:
            Bounding box ordered ``(minx, maxx, miny, maxy, mint, maxt)``; the
            temporal bounds are copied from the raster index.
        """
        minx, miny, maxx, maxy = polygon.bounds
        index_bounds = self.index.bounds
        return BoundingBox(minx, maxx, miny, maxy, index_bounds[4], index_bounds[5])

    def _get_intersected_bboxes(self, gdf: gpd.GeoDataFrame) -> gpd.GeoDataFrame:
        """Keep the rows whose bounding box is accepted by ``intersects_with_img``.

        Args:
            gdf: GeoDataFrame with a ``'bboxes'`` column.

        Returns:
            The subset of ``gdf`` whose bounding box matches at least one file
            found recursively under ``self.root`` with ``self.filename_glob``.
        """
        pathname = os.path.join(self.root, "**", self.filename_glob)
        file_list = list(glob.iglob(pathname, recursive=True))
        # Iterate the column directly: label lookups (gdf['bboxes'][i]) return
        # several rows at once when index labels are not unique.
        keep = [intersects_with_img(bbox, file_list) for bbox in gdf['bboxes']]
        return gdf.loc[keep]

    def _prepare_gdf(self, gdf: gpd.GeoDataFrame) -> gpd.GeoDataFrame:
        """Clean the GeoDataFrame and attach one bounding box per row.

        Steps: drop rows without geometry, reproject to the dataset CRS, drop
        duplicated rows and rows without target value, buffer point geometries
        into squares, compute bounding boxes, keep the rows matching a raster
        file, optionally standardise the target and optionally select
        ``self.target_indexes``.

        Args:
            gdf: input GeoDataFrame.

        Returns:
            The prepared GeoDataFrame, including a ``'bboxes'`` column.
        """
        # Missing geometries must be detected with notna(); an element-wise
        # comparison against None is not a reliable missing-value test.
        gdf = gdf.loc[gdf.geometry.notna()]
        gdf = gdf.to_crs(self.crs)
        gdf = gdf.drop_duplicates()
        gdf = gdf.dropna(subset=[self.target_var])

        # Points are expanded to squares (cap_style=3 is the square cap).
        # The check is a scalar boolean, so frames mixing geometry types no
        # longer raise; they are left unbuffered.
        # NOTE(review): only self.size[0] is used as buffer distance, so the
        # square side is twice that value - confirm this is intended.
        if not gdf.empty and (gdf.geom_type == "Point").all():
            gdf.geometry = gdf.buffer(self.size[0], cap_style=3)

        gdf['bboxes'] = [self._polygon_to_bbox(geom) for geom in gdf.geometry]
        # Copy the filtered frame so the assignments below act on an
        # independent object rather than a slice of the previous one.
        gdf = self._get_intersected_bboxes(gdf).copy()

        if self.normalize:
            target = gdf[self.target_var]
            gdf[self.target_var] = (target - target.mean()) / target.std()

        gdf = gdf.reset_index(drop=True)
        if self.target_indexes is not None:
            # Keep every column: 'bboxes' is needed by __init__ and
            # __getitem__, so the frame must not be reduced to the target.
            gdf = gdf.iloc[self.target_indexes]

        return gdf

    def __len__(self) -> int:
        """Return the number of regions of interest kept in the dataset."""
        return len(self.gdf)

    def __getitem__(self, idx) -> Dict[str, Any]:
        """Retrieve the raster sample and ground truth of one region of interest.

        Args:
            idx: positional index of the region of interest in ``self.gdf``.

        Returns:
            The sample returned by ``RasterDataset.__getitem__`` for the
            region's bounding box, with the target value added under ``"gt"``.

        Raises:
            IndexError: if ``idx`` is out of range; errors raised by
                ``RasterDataset.__getitem__`` propagate unchanged.
        """
        row = self.gdf.iloc[idx]

        # RasterDataset.__getitem__ already applies self.transforms, so they
        # must not be applied a second time here.
        sample = super().__getitem__(row['bboxes'])
        sample["gt"] = row[self.target_var]

        return sample