From bb176d6466f58919f89a450a0a773acf1257fa6d Mon Sep 17 00:00:00 2001 From: Alexey Pechnikov Date: Tue, 24 Sep 2024 01:42:10 +0700 Subject: [PATCH] Code refactoring to use Stack.PRM() for original subswaths and Stack.PRM_merged() for a single virtual raster. There is no need to store the merged PRMs on disk because these can be easily generated on the fly. --- pygmtsar/pygmtsar/IO.py | 10 +- pygmtsar/pygmtsar/Stack_align.py | 110 +-------------------- pygmtsar/pygmtsar/Stack_dem.py | 16 +--- pygmtsar/pygmtsar/Stack_geocode.py | 4 +- pygmtsar/pygmtsar/Stack_incidence.py | 6 +- pygmtsar/pygmtsar/Stack_prm.py | 116 ++++++++++++++++++++++- pygmtsar/pygmtsar/Stack_sbas.py | 4 +- pygmtsar/pygmtsar/Stack_tidal.py | 2 +- pygmtsar/pygmtsar/Stack_topo.py | 17 ++-- pygmtsar/pygmtsar/Stack_trans.py | 4 +- pygmtsar/pygmtsar/Stack_unwrap_snaphu.py | 2 +- 11 files changed, 148 insertions(+), 143 deletions(-) diff --git a/pygmtsar/pygmtsar/IO.py b/pygmtsar/pygmtsar/IO.py index 435209db..40f586e0 100644 --- a/pygmtsar/pygmtsar/IO.py +++ b/pygmtsar/pygmtsar/IO.py @@ -282,6 +282,7 @@ def get_filenames(self, pairs, name, add_subswath=False): # return ds_scaled.where(ds_scaled != 0) # 2.5e-07 is Sentinel-1 scale factor + # use original PRM files to get binary subswath file locations def open_data(self, dates=None, scale=2.5e-07, debug=False): import xarray as xr import pandas as pd @@ -301,10 +302,6 @@ def open_data(self, dates=None, scale=2.5e-07, debug=False): if not isinstance(subswaths, (str, int)): subswaths = ''.join(map(str, subswaths)) - # DEM extent in radar coordinates, merged reference PRM required - #print ('minx, miny, maxx, maxy', minx, miny, maxx, maxy) - extent_ra = np.round(self.get_extent_ra(subswath=subswaths[0]).bounds).astype(int) - if len(subswaths) == 1: # stack single subswath stack = [] @@ -320,7 +317,7 @@ def open_data(self, dates=None, scale=2.5e-07, debug=False): else: #offsets = {'bottoms': bottoms, 'lefts': lefts, 'rights': rights, 'bottom': minh, 'extent': [maxy, maxx], 'ylims': ylims, 'xlims': xlims} - offsets = self.subswaths_offsets(debug=debug) + offsets = self.prm_offsets(debug=debug) maxy, maxx = offsets['extent'] minh = offsets['bottom'] @@ -354,6 +351,9 @@ def open_data(self, dates=None, scale=2.5e-07, debug=False): stack.append(slc.assign_coords(date=date)) del slc + # DEM extent in radar coordinates, merged reference PRM required + #print ('minx, miny, maxx, maxy', minx, miny, maxx, maxy) + extent_ra = np.round(self.get_extent_ra().bounds).astype(int) # minx, miny, maxx, maxy = extent_ra ds = xr.concat(stack, dim='date').assign(date=pd.to_datetime(dates))\ .sel(y=slice(extent_ra[1], extent_ra[3]), x=slice(extent_ra[0], extent_ra[2])) \ diff --git a/pygmtsar/pygmtsar/Stack_align.py b/pygmtsar/pygmtsar/Stack_align.py index f16e411a..8f521f99 100644 --- a/pygmtsar/pygmtsar/Stack_align.py +++ b/pygmtsar/pygmtsar/Stack_align.py @@ -73,7 +73,10 @@ def _get_topo_llt(self, subswath, degrees, debug=False): warnings.filterwarnings('ignore') # add buffer around the cropped area for borders interpolation - dem_area = self.get_dem(subswath) + dem_area = self.get_dem() + + # TBD: crop dem to subswath + ny = int(np.round(degrees/dem_area.lat.diff('lat')[0])) nx = int(np.round(degrees/dem_area.lon.diff('lon')[0])) if debug: @@ -269,111 +272,6 @@ def _align_rep_subswath(self, subswath, date=None, degrees=12.0/3600, debug=Fals #if os.path.exists(filename): os.remove(filename) - def subswaths_offsets(self, debug=False): - import xarray as xr - import numpy as np - from scipy import constants - - subswaths = self.get_subswaths() - if not isinstance(subswaths, (str, int)): - subswaths = ''.join(map(str, subswaths)) - - if len(subswaths) == 1: - prm = self.PRM() - maxx, yvalid, num_patch = prm.get('num_rng_bins', 'num_valid_az', 'num_patches') - maxy = yvalid * num_patch - offsets = {'bottom': 0, 'extent': [maxy, maxx]} - if debug: - print ('offsets', offsets) - return offsets - - # calculate the offsets to merge subswaths - prms = [] - ylims = [] - xlims = [] - for subswath in subswaths: - #print (subswath) - prm = self.PRM(subswath=subswath) - prms.append(prm) - ylims.append(prm.get('num_valid_az')) - xlims.append(prm.get('num_rng_bins')) - - assert len(np.unique([prm.get('PRF') for prm in prms])), 'Image PRFs are not consistent' - assert len(np.unique([prm.get('rng_samp_rate') for prm in prms])), 'Image range sampling rates are not consistent' - - bottoms = [0] + [int(np.round(((prm.get('clock_start') - prms[0].get('clock_start')) * 86400 * prms[0].get('PRF')))) for prm in prms[1:]] - # head123: 0, 466, -408 - if debug: - print ('bottoms init', bottoms) - # minh: -408 - minh = min(bottoms) - if debug: - print ('minh', minh) - #head123: 408, 874, 0 - bottoms = np.asarray(bottoms) - minh - if debug: - print ('bottoms', bottoms) - - #ovl12,23: 2690, 2558 - ovls = [prm1.get('num_rng_bins') - \ - int(np.round(((prm2.get('near_range') - prm1.get('near_range')) / (constants.speed_of_light/ prm1.get('rng_samp_rate') / 2)))) \ - for (prm1, prm2) in zip(prms[:-1], prms[1:])] - if debug: - print ('ovls', ovls) - - #Writing the grid files..Size(69158x13075)... - #maxy: 13075 - # for SLC - maxy = max([prm.get('num_valid_az') + bottom for prm, bottom in zip(prms, bottoms)]) - if debug: - print ('maxy', maxy) - maxx = sum([prm.get('num_rng_bins') - ovl - 1 for prm, ovl in zip(prms, [-1] + ovls)]) - if debug: - print ('maxx', maxx) - - #Stitching location n1 = 1045 - #Stitching location n2 = 935 - ns = [np.ceil(-prm.get('rshift') + prm.get('first_sample') + 150.0).astype(int) for prm in prms[1:]] - ns = [10 if n < 10 else n for n in ns] - if debug: - print ('ns', ns) - - # left and right coordinates for every subswath valid area - lefts = [] - rights = [] - - # 1st - xlim = prms[0].get('num_rng_bins') - ovls[0] + ns[0] - lefts.append(0) - rights.append(xlim) - - # 2nd - if len(prms) == 2: - xlim = prms[1].get('num_rng_bins') - 1 - else: - # for 3 subswaths - xlim = prms[1].get('num_rng_bins') - ovls[1] + ns[1] - lefts.append(ns[0]) - rights.append(xlim) - - # 3rd - if len(prms) == 3: - xlim = prms[2].get('num_rng_bins') - 2 - lefts.append(ns[1]) - rights.append(xlim) - - # check and merge SLCs - sumx = sum([right-left for right, left in zip(rights, lefts)]) - if debug: - print ('assert maxx == sum(...)', maxx, sumx) - assert maxx == sumx, 'Incorrect output grid range dimension size' - - offsets = {'bottoms': bottoms, 'lefts': lefts, 'rights': rights, 'bottom': minh, 'extent': [maxy, maxx], 'ylims': ylims, 'xlims': xlims} - if debug: - print ('offsets', offsets) - - return offsets - def baseline_table(self, n_jobs=-1, debug=False): """ Generates a baseline table for Sentinel-1 data, containing dates and baseline components. diff --git a/pygmtsar/pygmtsar/Stack_dem.py b/pygmtsar/pygmtsar/Stack_dem.py index 11b282e6..4946ff3c 100644 --- a/pygmtsar/pygmtsar/Stack_dem.py +++ b/pygmtsar/pygmtsar/Stack_dem.py @@ -16,7 +16,7 @@ class Stack_dem(Stack_reframe): buffer_degrees = 0.02 - def get_extent_ra(self, subswath=None): + def get_extent_ra(self): """ minx, miny, maxx, maxy = np.round(geom.bounds).astype(int) """ @@ -25,7 +25,7 @@ def get_extent_ra(self, subswath=None): dem = self.get_dem() df = dem.isel(lon=[0,-1]).to_dataframe().reset_index() - geom = self.geocode(LineString(np.column_stack([df.lon, df.lat])), subswath=subswath) + geom = self.geocode(LineString(np.column_stack([df.lon, df.lat]))) return geom # def get_extent(self, grid=None, subswath=None): @@ -110,14 +110,12 @@ def set_dem(self, dem_filename): # 0.02 degrees works well worldwide but not in Siberia # minimum buffer size: 8 arc seconds for 90 m DEM # subswath argument is required for aligning - def get_dem(self, subswath=None): + def get_dem(self): """ Retrieve the digital elevation model (DEM) data. Parameters ---------- - subswath : str, optional - Subswath name. Default is None. Returns ------- @@ -131,12 +129,8 @@ def get_dem(self, subswath=None): Examples -------- - Get DEM for all the processed subswaths: topo_ll = stack.get_dem() - Get DEM for a single subswath IW1: - topo_ll = stack.get_dem(1) - Notes ----- This method retrieves the digital elevation model (DEM) data previously downloaded and stored in a NetCDF file. @@ -167,8 +161,8 @@ def get_dem(self, subswath=None): dem['lat'] = dem.lat.round(8) dem['lon'] = dem.lon.round(8) - # crop to reference scene and subswath if defined - bounds = self.get_bounds(self.get_reference(subswath)) + # crop to reference scene + bounds = self.get_bounds(self.get_reference()) return dem\ .transpose('lat','lon')\ .sel(lat=slice(bounds[1] - self.buffer_degrees, bounds[3] + self.buffer_degrees), diff --git a/pygmtsar/pygmtsar/Stack_geocode.py b/pygmtsar/pygmtsar/Stack_geocode.py index dd2f6e95..28af0044 100644 --- a/pygmtsar/pygmtsar/Stack_geocode.py +++ b/pygmtsar/pygmtsar/Stack_geocode.py @@ -56,7 +56,7 @@ def compute_geocode(self, coarsen=60.): # coarsen=4: # nearest: coords [array([596.42352295]), array([16978.65625])] # linear: coords [array([597.1080563]), array([16977.35608873])] - def geocode(self, geometry, subswath=None, z_offset=None): + def geocode(self, geometry, z_offset=None): """ Inverse geocode input geodataframe with 2D or 3D points. @@ -87,7 +87,7 @@ def geocode(self, geometry, subswath=None, z_offset=None): geometries = [geometry] dem = self.get_dem() - prm = self.PRM(subswath=subswath) + prm = self.PRM_merged() def coords_transform(coords): # uses external variables dem, prm diff --git a/pygmtsar/pygmtsar/Stack_incidence.py b/pygmtsar/pygmtsar/Stack_incidence.py index 1afee553..60a4c290 100644 --- a/pygmtsar/pygmtsar/Stack_incidence.py +++ b/pygmtsar/pygmtsar/Stack_incidence.py @@ -296,7 +296,7 @@ def los_displacement_mm(self, data): # constant is negative to make LOS = -1 * range change # constant is (1000 mm) / (4 * pi) - scale = -79.58 * self.PRM().get('radar_wavelength') + scale = -79.58 * self.PRM_merged().get('radar_wavelength') if isinstance(data, (list, tuple)): return scale*np.asarray(data) @@ -438,7 +438,7 @@ def elevation_m(self, data, baseline=1): # expected accuracy about 0.01% #wavelength, slant_range = self.PRM().get('radar_wavelength','SC_height') - wavelength, slant_range_start,slant_range_end = self.PRM().get('radar_wavelength', 'SC_height_start', 'SC_height_end') + wavelength, slant_range_start,slant_range_end = self.PRM_merged().get('radar_wavelength', 'SC_height_start', 'SC_height_end') incidence_angle = self.incidence_angle() slant_range = xr.DataArray(np.linspace(slant_range_start,slant_range_end, incidence_angle.shape[1]), @@ -465,7 +465,7 @@ def compute_satellite_look_vector(self, interactive=False): def SAT_look(z, lat, lon): coords = np.column_stack([lon.ravel(), lat.ravel(), z.ravel()]) # look_E look_N look_U - look = self.PRM().SAT_look(coords, binary=True)\ + look = self.PRM_merged().SAT_look(coords, binary=True)\ .astype(np.float32)\ .reshape(z.shape[0], z.shape[1], 6)[...,3:] return look diff --git a/pygmtsar/pygmtsar/Stack_prm.py b/pygmtsar/pygmtsar/Stack_prm.py index 7d94b08f..6cf45eec 100644 --- a/pygmtsar/pygmtsar/Stack_prm.py +++ b/pygmtsar/pygmtsar/Stack_prm.py @@ -41,4 +41,118 @@ def PRM(self, date=None, subswath=None): prefix = self.multistem_stem(subswath, date) filename = os.path.join(self.basedir, f'{prefix}.PRM') - return PRM.from_file(filename) \ No newline at end of file + return PRM.from_file(filename) + + def PRM_merged(self, date=None, offsets='auto'): + + if isinstance(offsets, str) and offsets == 'auto': + offsets = self.prm_offsets() + + maxy, maxx = offsets['extent'] + minh = offsets['bottom'] + return self.PRM(date=date).fix_merged(maxy, maxx, minh) + + def prm_offsets(self, debug=False): + import xarray as xr + import numpy as np + from scipy import constants + + subswaths = self.get_subswaths() + if not isinstance(subswaths, (str, int)): + subswaths = ''.join(map(str, subswaths)) + + if len(subswaths) == 1: + prm = self.PRM(subswath=int(subswaths)) + maxx, yvalid, num_patch = prm.get('num_rng_bins', 'num_valid_az', 'num_patches') + maxy = yvalid * num_patch + offsets = {'bottom': 0, 'extent': [maxy, maxx]} + if debug: + print ('offsets', offsets) + return offsets + + # calculate the offsets to merge subswaths + prms = [] + ylims = [] + xlims = [] + for subswath in subswaths: + #print (subswath) + prm = self.PRM(subswath=subswath) + prms.append(prm) + ylims.append(prm.get('num_valid_az')) + xlims.append(prm.get('num_rng_bins')) + + assert len(np.unique([prm.get('PRF') for prm in prms])), 'Image PRFs are not consistent' + assert len(np.unique([prm.get('rng_samp_rate') for prm in prms])), 'Image range sampling rates are not consistent' + + bottoms = [0] + [int(np.round(((prm.get('clock_start') - prms[0].get('clock_start')) * 86400 * prms[0].get('PRF')))) for prm in prms[1:]] + # head123: 0, 466, -408 + if debug: + print ('bottoms init', bottoms) + # minh: -408 + minh = min(bottoms) + if debug: + print ('minh', minh) + #head123: 408, 874, 0 + bottoms = np.asarray(bottoms) - minh + if debug: + print ('bottoms', bottoms) + + #ovl12,23: 2690, 2558 + ovls = [prm1.get('num_rng_bins') - \ + int(np.round(((prm2.get('near_range') - prm1.get('near_range')) / (constants.speed_of_light/ prm1.get('rng_samp_rate') / 2)))) \ + for (prm1, prm2) in zip(prms[:-1], prms[1:])] + if debug: + print ('ovls', ovls) + + #Writing the grid files..Size(69158x13075)... + #maxy: 13075 + # for SLC + maxy = max([prm.get('num_valid_az') + bottom for prm, bottom in zip(prms, bottoms)]) + if debug: + print ('maxy', maxy) + maxx = sum([prm.get('num_rng_bins') - ovl - 1 for prm, ovl in zip(prms, [-1] + ovls)]) + if debug: + print ('maxx', maxx) + + #Stitching location n1 = 1045 + #Stitching location n2 = 935 + ns = [np.ceil(-prm.get('rshift') + prm.get('first_sample') + 150.0).astype(int) for prm in prms[1:]] + ns = [10 if n < 10 else n for n in ns] + if debug: + print ('ns', ns) + + # left and right coordinates for every subswath valid area + lefts = [] + rights = [] + + # 1st + xlim = prms[0].get('num_rng_bins') - ovls[0] + ns[0] + lefts.append(0) + rights.append(xlim) + + # 2nd + if len(prms) == 2: + xlim = prms[1].get('num_rng_bins') - 1 + else: + # for 3 subswaths + xlim = prms[1].get('num_rng_bins') - ovls[1] + ns[1] + lefts.append(ns[0]) + rights.append(xlim) + + # 3rd + if len(prms) == 3: + xlim = prms[2].get('num_rng_bins') - 2 + lefts.append(ns[1]) + rights.append(xlim) + + # check and merge SLCs + sumx = sum([right-left for right, left in zip(rights, lefts)]) + if debug: + print ('assert maxx == sum(...)', maxx, sumx) + assert maxx == sumx, 'Incorrect output grid range dimension size' + + offsets = {'bottoms': bottoms, 'lefts': lefts, 'rights': rights, 'bottom': minh, 'extent': [maxy, maxx], 'ylims': ylims, 'xlims': xlims} + if debug: + print ('offsets', offsets) + + return offsets diff --git a/pygmtsar/pygmtsar/Stack_sbas.py b/pygmtsar/pygmtsar/Stack_sbas.py index 370e2ebd..410b3f9e 100644 --- a/pygmtsar/pygmtsar/Stack_sbas.py +++ b/pygmtsar/pygmtsar/Stack_sbas.py @@ -90,10 +90,10 @@ def baseline_table(dates): import pandas as pd import numpy as np - prm_ref = self.PRM() + prm_ref = self.PRM_merged() data = [] for date in dates: - prm_rep = self.PRM(date) + prm_rep = self.PRM_merged(date) BPL, BPR = prm_ref.SAT_baseline(prm_rep).get('B_parallel', 'B_perpendicular') data.append({'date':date, 'BPL':BPL, 'BPR':BPR}) df = pd.DataFrame(data).set_index('date') diff --git a/pygmtsar/pygmtsar/Stack_tidal.py b/pygmtsar/pygmtsar/Stack_tidal.py index 08f35369..7e619dec 100644 --- a/pygmtsar/pygmtsar/Stack_tidal.py +++ b/pygmtsar/pygmtsar/Stack_tidal.py @@ -271,7 +271,7 @@ def _tidal(self, date, grid): stdin_data = buffer.getvalue() #print ('stdin_data', stdin_data) - SC_clock_start, SC_clock_stop = self.PRM(date).get('SC_clock_start', 'SC_clock_stop') + SC_clock_start, SC_clock_stop = self.PRM_merged(date).get('SC_clock_start', 'SC_clock_stop') dt = (SC_clock_start + SC_clock_stop)/2 argv = ['solid_tide', str(dt)] #cwd = os.path.dirname(self.filename) if self.filename is not None else '.' diff --git a/pygmtsar/pygmtsar/Stack_topo.py b/pygmtsar/pygmtsar/Stack_topo.py index 933299eb..9c83f09b 100644 --- a/pygmtsar/pygmtsar/Stack_topo.py +++ b/pygmtsar/pygmtsar/Stack_topo.py @@ -128,7 +128,8 @@ def block_phase_dask(block_topo, y_chunk, x_chunk, prm1, prm2): # compute the time span and the time spacing tspan = 86400 * abs(prm2.get('SC_clock_stop') - prm2.get('SC_clock_start')) - assert (tspan >= 0.01) and (prm2.get('PRF') >= 0.01), 'Check sc_clock_start, sc_clock_end, or PRF' + assert (tspan >= 0.01) and (prm2.get('PRF') >= 0.01), \ + f"ERROR in sc_clock_start={prm2.get('SC_clock_start')}, sc_clock_stop={prm2.get('SC_clock_stop')}, or PRF={prm2.get('PRF')}" # setup the default parameters drange = constants.speed_of_light / (2 * prm2.get('rng_samp_rate')) @@ -193,19 +194,17 @@ def block_phase_dask(block_topo, y_chunk, x_chunk, prm1, prm2): # immediately prepare PRM # here is some delay on the function call but the actual processing is faster - offsets = self.subswaths_offsets(debug=debug) - maxy, maxx = offsets['extent'] - minh = offsets['bottom'] - - def prepare_prms(pair, *args): + # define offset once to apply to all the PRMs + offsets = self.prm_offsets(debug=debug) + def prepare_prms(pair, offsets): date1, date2 = pair - prm1 = self.PRM(date1).fix_merged(*args) - prm2 = self.PRM(date2).fix_merged(*args) + prm1 = self.PRM_merged(date1, offsets=offsets) + prm2 = self.PRM_merged(date2, offsets=offsets) prm2.set(prm1.SAT_baseline(prm2, tail=9)).fix_aligned() prm1.set(prm1.SAT_baseline(prm1).sel('SC_height','SC_height_start','SC_height_end')).fix_aligned() return (prm1, prm2) - prms = joblib.Parallel(n_jobs=-1)(joblib.delayed(prepare_prms)(pair, maxy, maxx, minh) for pair in pairs) + prms = joblib.Parallel(n_jobs=-1)(joblib.delayed(prepare_prms)(pair, offsets) for pair in pairs) # fill NaNs by 0 and expand to 3d topo2d = da.where(da.isnan(topo.data), 0, topo.data) diff --git a/pygmtsar/pygmtsar/Stack_trans.py b/pygmtsar/pygmtsar/Stack_trans.py index 011cc082..de923a46 100644 --- a/pygmtsar/pygmtsar/Stack_trans.py +++ b/pygmtsar/pygmtsar/Stack_trans.py @@ -17,7 +17,7 @@ def define_trans_grid(self, coarsen): # select radar coordinates extent #rng_max, yvalid, num_patch = self.PRM().get('num_rng_bins', 'num_valid_az', 'num_patches') #azi_max = yvalid * num_patch - azi_max, rng_max = self.subswaths_offsets()['extent'] + azi_max, rng_max = self.prm_offsets()['extent'] #print ('azi_max', azi_max, 'rng_max', rng_max) # this grid covers the full interferogram area # common single pixel resolution @@ -99,7 +99,7 @@ def compute_trans(self, coarsen, dem='auto', interactive=False): coarsen = self.get_coarsen(coarsen) - prm = self.PRM() + prm = self.PRM_merged() def SAT_llt2rat(lats, lons, zs): # for binary=True values outside of the scene missed and the array is not complete # 4th and 5th coordinates are the same as input lat, lon diff --git a/pygmtsar/pygmtsar/Stack_unwrap_snaphu.py b/pygmtsar/pygmtsar/Stack_unwrap_snaphu.py index 30e12206..b0cbf4c9 100644 --- a/pygmtsar/pygmtsar/Stack_unwrap_snaphu.py +++ b/pygmtsar/pygmtsar/Stack_unwrap_snaphu.py @@ -184,7 +184,7 @@ def snaphu_config(self, defomax=0, **kwargs): import os import joblib - tiledir = os.path.splitext(self.PRM().filename)[0] + tiledir = os.path.splitext(self.PRM_merged().filename)[0] n_jobs = joblib.cpu_count() conf_basic = f"""