Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

add ability to load fewer polygons, start trying to fix colorbar #34

Merged
merged 5 commits into from
Nov 8, 2023
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
61 changes: 52 additions & 9 deletions src/sweets/_missing_data.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,7 @@
import numpy as np
import pandas as pd
from dolphin._types import Filename
from matplotlib.colors import BoundaryNorm, ListedColormap
from osgeo import gdal
from shapely import geometry, intersection_all, union_all, wkt
from tqdm.contrib.concurrent import thread_map
Expand All @@ -24,7 +25,7 @@


def get_geodataframe(
gslc_files: Iterable[Filename], max_workers: int = 5
gslc_files: Iterable[Filename], max_workers: int = 5, one_per_burst: bool = True
) -> gpd.GeoDataFrame:
"""Get a GeoDataFrame of the CSLC footprints.

Expand All @@ -34,13 +35,30 @@ def get_geodataframe(
List of CSLC files.
max_workers : int
Number of threads to use.
one_per_burst : bool, default=True
If True, only keep one footprint per burst ID.
"""
gslc_files = list(gslc_files) # make sure generator doesn't deplete after first run
polygons = thread_map(get_cslc_polygon, gslc_files, max_workers=max_workers)
if one_per_burst:
from dolphin.opera_utils import group_by_burst

burst_to_file_list = group_by_burst(gslc_files)
slc_files = [file_list[0] for file_list in burst_to_file_list.values()]
unique_polygons = thread_map(
get_cslc_polygon, slc_files, max_workers=max_workers
)
assert len(unique_polygons) == len(burst_to_file_list)
# Repeat the polygons for each burst
polygons: list[geometry.Polygon] = []
for burst_id, p in zip(burst_to_file_list, unique_polygons):
for _ in range(len(burst_to_file_list[burst_id])):
polygons.append(p)
else:
polygons = thread_map(get_cslc_polygon, gslc_files, max_workers=max_workers)

gdf = gpd.GeoDataFrame(geometry=polygons, crs="EPSG:4326")
gdf["count"] = 1
gdf["filename"] = [p.stem for p in gslc_files]
gdf["filename"] = [Path(p).stem for p in gslc_files]
gdf["date"] = pd.to_datetime(gdf.filename.str.split("_").str[3])
gdf["burst_id"] = gdf.filename.str[:15]
return gdf
Expand Down Expand Up @@ -68,13 +86,19 @@ def get_cslc_polygon(


def get_common_dates(
*, gslc_files: Optional[Sequence[Filename]] = None, gdf=None
*,
gslc_files: Optional[Sequence[Filename]] = None,
gdf=None,
max_workers: int = 5,
one_per_burst: bool = True,
) -> list[str]:
"""Get the date common to all GSLCs."""
if gdf is None:
if gslc_files is None:
raise ValueError("Need `gdf` or `gslc_files`")
gdf = get_geodataframe(gslc_files)
gdf = get_geodataframe(
gslc_files, max_workers=max_workers, one_per_burst=one_per_burst
)

grouped_by_burst = _get_per_burst_df(gdf)
common_dates = list(
Expand Down Expand Up @@ -109,27 +133,46 @@ def plot_count_per_burst(
*,
gdf: Optional[gpd.GeoDataFrame] = None,
gslc_files: Optional[Sequence[Filename]] = None,
one_per_burst: bool = True,
ax: Optional[plt.Axes] = None,
) -> None:
"""Plot the number of GSLC files found per burst."""
if gdf is None:
if gslc_files is None:
raise ValueError("Need `gdf` or `gslc_files`")
gdf = get_geodataframe(gslc_files)
gdf = get_geodataframe(gslc_files, one_per_burst=one_per_burst)
gdf_grouped = _get_per_burst_df(gdf)

if ax is None:
fig, ax = plt.subplots(ncols=1)

# Make a unique colormap for the specific count values
unique_counts = np.unique(gdf_grouped["count"])

cmap = ListedColormap(plt.cm.tab10(np.linspace(0, 1, len(unique_counts))))
boundaries = np.concatenate([[unique_counts[0] - 1], unique_counts + 1])
norm = BoundaryNorm(boundaries, cmap.N)

kwds = dict(
column="count",
legend=True,
cmap="tab10",
legend_kwds={"label": "Count", "orientation": "horizontal"},
legend=False,
cmap=cmap,
norm=norm,
linewidth=0.8,
edgecolor="0.8",
)

gdf_grouped.plot(ax=ax, **kwds)
cbar = plt.colorbar(
plt.cm.ScalarMappable(norm=norm, cmap=cmap), ax=ax, orientation="horizontal"
)
cbar.set_label("Count")
cbar_ticks = [
(boundaries[i] + boundaries[i + 1]) / 2 for i in range(len(boundaries) - 1)
]
cbar.set_ticks(cbar_ticks)
cbar.set_ticklabels(unique_counts)

return gdf_grouped


Expand Down
Loading