Skip to content

Commit

Permalink
Merge pull request #34 from scottstanie/missing-data-plots
Browse files Browse the repository at this point in the history
add ability to load fewer polygons, start trying to fix colorbar
  • Loading branch information
scottstanie authored Nov 8, 2023
2 parents b9377e0 + da8d233 commit 00f7dfb
Showing 1 changed file with 52 additions and 9 deletions.
61 changes: 52 additions & 9 deletions src/sweets/_missing_data.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,7 @@
import numpy as np
import pandas as pd
from dolphin._types import Filename
from matplotlib.colors import BoundaryNorm, ListedColormap
from osgeo import gdal
from shapely import geometry, intersection_all, union_all, wkt
from tqdm.contrib.concurrent import thread_map
Expand All @@ -24,7 +25,7 @@


def get_geodataframe(
    gslc_files: Iterable[Filename], max_workers: int = 5, one_per_burst: bool = True
) -> gpd.GeoDataFrame:
    """Get a GeoDataFrame of the CSLC footprints.

    Parameters
    ----------
    gslc_files : Iterable[Filename]
        List of CSLC files.
    max_workers : int
        Number of threads to use.
    one_per_burst : bool, default=True
        If True, read the footprint polygon from only one file per burst ID
        and reuse it for the burst's remaining files (avoids redundant reads;
        assumes footprints within a burst are identical — TODO confirm).
    """
    gslc_files = list(gslc_files)  # make sure generator doesn't deplete after first run
    if one_per_burst:
        from dolphin.opera_utils import group_by_burst

        burst_to_file_list = group_by_burst(gslc_files)
        # Read the polygon from just the first file of each burst group
        first_files = [file_list[0] for file_list in burst_to_file_list.values()]
        unique_polygons = thread_map(
            get_cslc_polygon, first_files, max_workers=max_workers
        )
        assert len(unique_polygons) == len(burst_to_file_list)
        # Repeat each burst's polygon once per file in that burst group
        polygons: list[geometry.Polygon] = []
        for burst_id, poly in zip(burst_to_file_list, unique_polygons):
            polygons.extend([poly] * len(burst_to_file_list[burst_id]))
        # BUG FIX: re-order the file list to match the burst-grouped polygon
        # order. `polygons` follows group_by_burst's grouping, so building the
        # `filename` column from the original `gslc_files` ordering would
        # misalign geometry rows with filename/date/burst_id whenever grouping
        # reorders the files.
        gslc_files = [
            f for file_list in burst_to_file_list.values() for f in file_list
        ]
    else:
        polygons = thread_map(get_cslc_polygon, gslc_files, max_workers=max_workers)

    gdf = gpd.GeoDataFrame(geometry=polygons, crs="EPSG:4326")
    gdf["count"] = 1
    # Path() wrapper: entries may be str or Path; .stem needs a Path
    gdf["filename"] = [Path(p).stem for p in gslc_files]
    # OPERA CSLC filename convention: date is the 4th underscore-separated field
    gdf["date"] = pd.to_datetime(gdf.filename.str.split("_").str[3])
    # Burst ID is the first 15 characters of the stem
    gdf["burst_id"] = gdf.filename.str[:15]
    return gdf
Expand Down Expand Up @@ -68,13 +86,19 @@ def get_cslc_polygon(


def get_common_dates(
*, gslc_files: Optional[Sequence[Filename]] = None, gdf=None
*,
gslc_files: Optional[Sequence[Filename]] = None,
gdf=None,
max_workers: int = 5,
one_per_burst: bool = True,
) -> list[str]:
"""Get the date common to all GSLCs."""
if gdf is None:
if gslc_files is None:
raise ValueError("Need `gdf` or `gslc_files`")
gdf = get_geodataframe(gslc_files)
gdf = get_geodataframe(
gslc_files, max_workers=max_workers, one_per_burst=one_per_burst
)

grouped_by_burst = _get_per_burst_df(gdf)
common_dates = list(
def plot_count_per_burst(
    *,
    gdf: Optional[gpd.GeoDataFrame] = None,
    gslc_files: Optional[Sequence[Filename]] = None,
    one_per_burst: bool = True,
    ax: Optional[plt.Axes] = None,
) -> gpd.GeoDataFrame:
    """Plot the number of GSLC files found per burst.

    Parameters
    ----------
    gdf : Optional[gpd.GeoDataFrame]
        Footprint GeoDataFrame from `get_geodataframe`. If None, it is
        built from `gslc_files`.
    gslc_files : Optional[Sequence[Filename]]
        List of CSLC files; required when `gdf` is None.
    one_per_burst : bool, default=True
        Passed through to `get_geodataframe` when building `gdf`.
    ax : Optional[plt.Axes]
        Axes to draw into; a new figure/axes is created when None.

    Returns
    -------
    gpd.GeoDataFrame
        The per-burst grouped GeoDataFrame that was plotted.

    Raises
    ------
    ValueError
        If neither `gdf` nor `gslc_files` is provided.
    """
    if gdf is None:
        if gslc_files is None:
            raise ValueError("Need `gdf` or `gslc_files`")
        gdf = get_geodataframe(gslc_files, one_per_burst=one_per_burst)
    gdf_grouped = _get_per_burst_df(gdf)

    if ax is None:
        # BUG FIX: drop unused `fig` binding
        _, ax = plt.subplots(ncols=1)

    # Build a discrete colormap with one color per distinct count value so the
    # colorbar shows exact counts rather than a continuous gradient.
    unique_counts = np.unique(gdf_grouped["count"])

    cmap = ListedColormap(plt.cm.tab10(np.linspace(0, 1, len(unique_counts))))
    # Bin edges bracketing each distinct count: [first-1, c1+1, c2+1, ...]
    boundaries = np.concatenate([[unique_counts[0] - 1], unique_counts + 1])
    norm = BoundaryNorm(boundaries, cmap.N)

    kwds = dict(
        column="count",
        # Disable geopandas' built-in legend; a manual colorbar is added below
        # so tick placement/labels can be controlled exactly.
        legend=False,
        cmap=cmap,
        norm=norm,
        linewidth=0.8,
        edgecolor="0.8",
    )

    gdf_grouped.plot(ax=ax, **kwds)
    cbar = plt.colorbar(
        plt.cm.ScalarMappable(norm=norm, cmap=cmap), ax=ax, orientation="horizontal"
    )
    cbar.set_label("Count")
    # Place one tick at the center of each bin, labeled with the actual count
    cbar_ticks = [
        (boundaries[i] + boundaries[i + 1]) / 2 for i in range(len(boundaries) - 1)
    ]
    cbar.set_ticks(cbar_ticks)
    # BUG FIX: label with plain ints (was the raw numpy array; also the
    # declared return type was `None` despite returning `gdf_grouped`)
    cbar.set_ticklabels([int(c) for c in unique_counts])

    return gdf_grouped


Expand Down

0 comments on commit 00f7dfb

Please sign in to comment.