Skip to content

Commit

Permalink
Code refactoring
Browse files Browse the repository at this point in the history
  • Loading branch information
AlexeyPechnikov committed Aug 27, 2024
1 parent b9306e5 commit e7205df
Show file tree
Hide file tree
Showing 4 changed files with 139 additions and 415 deletions.
35 changes: 2 additions & 33 deletions pygmtsar/pygmtsar/Stack_dem.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,7 @@
from .Stack_reframe import Stack_reframe
from .PRM import PRM
from .tqdm_dask import tqdm_dask
from .utils import utils

class Stack_dem(Stack_reframe):

Expand Down Expand Up @@ -60,45 +61,13 @@ def get_geoid(self, grid=None):
See EGM96 geoid heights on http://icgem.gfz-potsdam.de/tom_longtime
"""
import xarray as xr
import dask.array as da
import os
import importlib.resources as resources
import warnings
# suppress Dask warning "RuntimeWarning: invalid value encountered in divide"
warnings.filterwarnings('ignore')
warnings.filterwarnings('ignore', module='dask')
warnings.filterwarnings('ignore', module='dask.core')

# use outer variable geoid
def interpolate_chunk(grid_chunk, grid_lat_chunk, grid_lon_chunk, method='cubic'):
dlat, dlon = float(geoid.lat.diff('lat')[0]), float(geoid.lon.diff('lon')[0])
geoid_chunk = geoid.sel(
lat=slice(grid_lat_chunk[0]-2*dlat, grid_lat_chunk[-1]+2*dlat),
lon=slice(grid_lon_chunk[0]-2*dlon, grid_lon_chunk[-1]+2*dlon)
).compute()
#print ('geoid_chunk', geoid_chunk)
return geoid_chunk.interp({'lat': grid_lat_chunk, 'lon': grid_lon_chunk}, method=method)


with resources.as_file(resources.files('pygmtsar.data') / 'geoid_egm96_icgem.grd') as geoid_filename:
geoid = xr.open_dataarray(geoid_filename, engine=self.netcdf_engine, chunks=self.netcdf_chunksize).rename({'y': 'lat', 'x': 'lon'})
if grid is not None:
# Xarray interpolation struggles with large grids
#geoid = geoid.interp_like(grid.coords, method='linear')
# grid.data is needed only to prevent excessive memory usage during interpolation
geoid_grid = da.blockwise(
interpolate_chunk,
'ij',
grid.data,
'ij',
grid.lat.data,
'i',
grid.lon.data,
'j',
dtype=geoid.dtype
)
return xr.DataArray(geoid_grid, coords=grid.coords).rename(geoid.name)

return utils.interp2d_like(geoid, grid)
return geoid

def set_dem(self, dem_filename):
Expand Down
303 changes: 30 additions & 273 deletions pygmtsar/pygmtsar/Stack_phasediff.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,7 @@
from .Stack_topo import Stack_topo
from .tqdm_dask import tqdm_dask
from .PRM import PRM
from .utils import utils

class Stack_phasediff(Stack_topo):

Expand Down Expand Up @@ -442,20 +443,10 @@ def correlation(self, phase, intensity, debug=False):
# return xr.concat(stack, dim='pair').assign_coords(ref=coord_ref, rep=coord_rep, pair=coord_pair).rename('phasediff')

def phasediff(self, pairs, data='auto', topo='auto', phase=None, method='nearest', joblib_backend=None, debug=False):
import pandas as pd
import dask
import dask.dataframe
#import dask
import dask.array as da
import xarray as xr
import numpy as np
#from tqdm.auto import tqdm
import joblib
from joblib.externals import loky
loky.get_reusable_executor(kill_workers=True).shutdown(wait=True)
import warnings
# suppress Dask warning "RuntimeWarning: invalid value encountered in divide"
warnings.filterwarnings('ignore')
warnings.filterwarnings('ignore', module='dask')
warnings.filterwarnings('ignore', module='dask.core')

if debug:
print ('DEBUG: phasediff')
Expand All @@ -467,272 +458,38 @@ def phasediff(self, pairs, data='auto', topo='auto', phase=None, method='nearest
pairs, dates = self.get_pairs(pairs, dates=True)
pairs = pairs[['ref', 'rep']].astype(str).values

if isinstance(topo, str) and topo == 'auto':
topo = self.get_topo()

# calculate the combined earth curvature and topography correction
def calc_drho(rho, topo, earth_radius, height, b, alpha, Bx):
sina = np.sin(alpha)
cosa = np.cos(alpha)
c = earth_radius + height
# compute the look angle using equation (C26) in Appendix C
# GMTSAR uses long double here
#ret = earth_radius + topo.astype(np.longdouble)
ret = earth_radius + topo
cost = ((rho**2 + c**2 - ret**2) / (2. * rho * c))
#if (cost >= 1.)
# die("calc_drho", "cost >= 0");
sint = np.sqrt(1. - cost**2)
# Compute the offset effect from non-parallel orbit
term1 = rho**2 + b**2 - 2 * rho * b * (sint * cosa - cost * sina) - Bx**2
drho = -rho + np.sqrt(term1)
del term1, sint, cost, ret, c, cosa, sina
return drho

def block_phasediff(date1, date2, prm1, prm2, ylim, xlim):
# use outer variables date, stack_prm
# disable "distributed.utils_perf - WARNING - full garbage collections ..."
try:
from dask.distributed import utils_perf
utils_perf.disable_gc_diagnosis()
except ImportError:
from distributed.gc import disable_gc_diagnosis
disable_gc_diagnosis()
import warnings
# suppress Dask warning "RuntimeWarning: invalid value encountered in divide"
warnings.filterwarnings('ignore')
warnings.filterwarnings('ignore', module='dask')
warnings.filterwarnings('ignore', module='dask.core')

# for lazy Dask dataframes
#prm1 = PRM(prm1)
#prm2 = PRM(prm2)
#prm1, prm2 = stack_prm[stack_idx]
#data1, data2 = stack_data[stack_idx]
data1 = data.sel(date=date1)
data2 = data.sel(date=date2)

# convert indices 0.5, 1.5,... to 0,1,... for easy calculations
block_data1 = data1.isel(y=slice(ylim[0], ylim[1]), x=slice(xlim[0], xlim[1])).compute(n_workers=1)
block_data2 = data2.isel(y=slice(ylim[0], ylim[1]), x=slice(xlim[0], xlim[1])).compute(n_workers=1)
del data1, data2

if abs(block_data1).sum() == 0:
intf = np.nan * xr.zeros_like(block_data1)
del block_data1, block_data2
return intf

ys = block_data1.y.astype(int)
xs = block_data1.x.astype(int)

block_data1 = block_data1.assign_coords(y=ys, x=xs)
block_data2 = block_data2.assign_coords(y=ys, x=xs)

if isinstance(topo, xr.DataArray):
dy, dx = topo.y.diff('y').item(0), topo.x.diff('x').item(0)

# use outer variables topo, data1, data2, prm1, prm2
# build topo block
if not isinstance(topo, xr.DataArray):
# topography is a constant, typically, zero
block_topo = topo * xr.ones_like(block_data1, dtype=np.float32)
elif dy == 1 and dx == 1:
# topography is already in the original resolution
block_topo = topo.isel(y=slice(ylim[0], ylim[1]), x=slice(xlim[0], xlim[1]))\
.compute(n_workers=1)\
.fillna(0)\
.assign_coords(y=ys, x=xs)
else:
# topography resolution is different, interpolation with extrapolation required
# convert indices 0.5, 1.5,... to 0,1,... for easy calculations
# fill NaNs with zero because the DEM is typically missing outside of land areas
block_topo = topo.sel(y=slice(ys[0]-2*dy, ys[-1]+2*dy), x=slice(xs[0]-2*dx, xs[-1]+2*dx))\
.compute(n_workers=1)\
.fillna(0)\
.interp({'y': block_data1.y, 'x': block_data1.x}, method=method, kwargs={'fill_value': 'extrapolate'})\
.assign_coords(y=ys, x=xs)

if phase is not None:
dy, dx = phase.y.diff('y').item(0), phase.x.diff('x').item(0)
if dy == 1 and dx == 1:
# phase is already in the original resolution
block_phase = phase.sel(pair=f'{date1} {date2}').isel(y=slice(ylim[0], ylim[1]), x=slice(xlim[0], xlim[1]))\
.compute(n_workers=1)\
.assign_coords(y=ys, x=xs)
else:
# phase resolution is different, interpolation with extrapolation required
# convert indices 0.5, 1.5,... to 0,1,... for easy calculations
block_phase = phase.sel(pair=f'{date1} {date2}').sel(y=slice(ys[0]-2*dy, ys[-1]+2*dy), x=slice(xs[0]-2*dx, xs[-1]+2*dx))\
.compute(n_workers=1)\
.interp({'y': block_data1.y, 'x': block_data1.x}, method=method, kwargs={'fill_value': 'extrapolate'})\
.assign_coords(y=ys, x=xs)
# set dimensions
xdim = prm1.get('num_rng_bins')
ydim = prm1.get('num_patches') * prm1.get('num_valid_az')

# set heights
htc = prm1.get('SC_height')
ht0 = prm1.get('SC_height_start')
htf = prm1.get('SC_height_end')

# compute the time span and the time spacing
tspan = 86400 * abs(prm2.get('SC_clock_stop') - prm2.get('SC_clock_start'))
assert (tspan >= 0.01) and (prm2.get('PRF') >= 0.01), 'Check sc_clock_start, sc_clock_end, or PRF'

from scipy import constants
# setup the default parameters
# constant from GMTSAR code for consistency
#SOL = 299792456.0
drange = constants.speed_of_light / (2 * prm2.get('rng_samp_rate'))
#drange = SOL / (2 * prm2.get('rng_samp_rate'))
alpha = prm2.get('alpha_start') * np.pi / 180
cnst = -4 * np.pi / prm2.get('radar_wavelength')

# calculate initial baselines
Bh0 = prm2.get('baseline_start') * np.cos(prm2.get('alpha_start') * np.pi / 180)
Bv0 = prm2.get('baseline_start') * np.sin(prm2.get('alpha_start') * np.pi / 180)
Bhf = prm2.get('baseline_end') * np.cos(prm2.get('alpha_end') * np.pi / 180)
Bvf = prm2.get('baseline_end') * np.sin(prm2.get('alpha_end') * np.pi / 180)
Bx0 = prm2.get('B_offset_start')
Bxf = prm2.get('B_offset_end')

# first case is quadratic baseline model, second case is default linear model
if prm2.get('baseline_center') != 0 or prm2.get('alpha_center') != 0 or prm2.get('B_offset_center') != 0:
Bhc = prm2.get('baseline_center') * np.cos(prm2.get('alpha_center') * np.pi / 180)
Bvc = prm2.get('baseline_center') * np.sin(prm2.get('alpha_center') * np.pi / 180)
Bxc = prm2.get('B_offset_center')

dBh = (-3 * Bh0 + 4 * Bhc - Bhf) / tspan
dBv = (-3 * Bv0 + 4 * Bvc - Bvf) / tspan
ddBh = (2 * Bh0 - 4 * Bhc + 2 * Bhf) / (tspan * tspan)
ddBv = (2 * Bv0 - 4 * Bvc + 2 * Bvf) / (tspan * tspan)

dBx = (-3 * Bx0 + 4 * Bxc - Bxf) / tspan
ddBx = (2 * Bx0 - 4 * Bxc + 2 * Bxf) / (tspan * tspan)
else:
dBh = (Bhf - Bh0) / tspan
dBv = (Bvf - Bv0) / tspan
dBx = (Bxf - Bx0) / tspan
ddBh = ddBv = ddBx = 0

# calculate height increment
dht = (-3 * ht0 + 4 * htc - htf) / tspan
ddht = (2 * ht0 - 4 * htc + 2 * htf) / (tspan * tspan)

# multiply by xr.ones_like(block_topo) for correct broadcasting
near_range = xr.ones_like(block_topo)*(prm1.get('near_range') + \
block_topo.x * (1 + prm1.get('stretch_r')) * drange) + \
xr.ones_like(block_topo)*(block_topo.y * prm1.get('a_stretch_r') * drange)

# calculate the change in baseline and height along the frame if topoflag is on
time = block_topo.y * tspan / (ydim - 1)
Bh = Bh0 + dBh * time + ddBh * time**2
Bv = Bv0 + dBv * time + ddBv * time**2
Bx = Bx0 + dBx * time + ddBx * time**2
B = np.sqrt(Bh * Bh + Bv * Bv)
alpha = np.arctan2(Bv, Bh)
height = ht0 + dht * time + ddht * time**2

# calculate the combined earth curvature and topography correction
drho = calc_drho(near_range, block_topo, prm1.get('earth_radius'), height, B, alpha, Bx)

# make topographic and model phase corrections
# GMTSAR uses float32 complex operations with precision loss
#phase_shift = np.exp(1j * (cnst * drho).astype(np.float32))
if phase is not None:
phase_shift = np.exp(1j * (cnst * drho - block_phase))
# or the same expression written another way
#phase_shift = np.exp(1j * (cnst * drho)) * np.exp(-1j * block_phase)
del block_phase
else:
phase_shift = np.exp(1j * (cnst * drho))
del block_topo, near_range, drho, height, B, alpha, Bx, Bv, Bh, time

# calculate phase difference
intf = block_data1 * phase_shift * np.conj(block_data2)
del block_data1, block_data2, phase_shift
return intf.astype(np.complex64)

# # prepare lazy PRM
# # this is the "true way" but processing is ~40% slower due to additional Dask tasks
# def block_prms(date1, date2):
# prm1 = self.PRM(date1)
# prm2 = self.PRM(date2)
# prm2.set(prm1.SAT_baseline(prm2, tail=9)).fix_aligned()
# prm1.set(prm1.SAT_baseline(prm1).sel('SC_height','SC_height_start','SC_height_end')).fix_aligned()
# return (prm1.df, prm2.df)
# # Define metadata explicitly to match PRM dataframe
# prm_meta = pd.DataFrame(columns=['name', 'value']).astype({'name': 'str', 'value': 'object'}).set_index('name')

# immediately prepare PRM
# this adds some delay at the function call, but the actual processing is faster
def prepare_prms(pair):
date1, date2 = pair
prm1 = self.PRM(date1)
prm2 = self.PRM(date2)
prm2.set(prm1.SAT_baseline(prm2, tail=9)).fix_aligned()
prm1.set(prm1.SAT_baseline(prm1).sel('SC_height','SC_height_start','SC_height_end')).fix_aligned()
return {(date1, date2): (prm1, prm2)}

#with self.tqdm_joblib(tqdm(desc=f'Pre-Processing PRM', total=len(pairs))) as progress_bar:
prms = joblib.Parallel(n_jobs=-1, backend=joblib_backend)(joblib.delayed(prepare_prms)(pair) for pair in pairs)
# convert the list of dicts to a single dict
prms = {k: v for d in prms for k, v in d.items()}

if isinstance(data, str) and data == 'auto':
# open datafiles required for all the pairs
data = self.open_data(dates)

# define blocks
chunks = data.chunks
ychunks, xchunks = chunks[1], chunks[2]
ychunks = np.concatenate([[0], np.cumsum(ychunks)])
xchunks = np.concatenate([[0], np.cumsum(xchunks)])
ylims = [(y1, y2) for y1, y2 in zip(ychunks, ychunks[1:])]
xlims = [(x1, x2) for x1, x2 in zip(xchunks, xchunks[1:])]
#print ('ylims', ylims)
#print ('xlims', xlims)

stack = []
for stack_idx, pair in enumerate(pairs):
date1, date2 = pair

# Create a Dask DataFrame with provided metadata for each Dask block
#prms = dask.delayed(block_prms)(date1, date2)
#prm1 = dask.dataframe.from_delayed(dask.delayed(prms[0]), meta=prm_meta)
#prm2 = dask.dataframe.from_delayed(dask.delayed(prms[1]), meta=prm_meta)
prm1, prm2 = prms[(date1, date2)]

if topo is None:
# calculation is straightforward and does not require delayed wrappers
intf = (data.sel(date=date1) * np.conj(data.sel(date=date2)))
else:
blocks_total = []
for ylim in ylims:
blocks = []
for xlim in xlims:
delayed = dask.delayed(block_phasediff)(date1, date2, prm1, prm2, ylim, xlim)
block = dask.array.from_delayed(delayed,
shape=((ylim[1]-ylim[0]), (xlim[1]-xlim[0])),
dtype=np.complex64)
blocks.append(block)
del block, delayed
blocks_total.append(blocks)
del blocks
intf = xr.DataArray(dask.array.block(blocks_total), coords={'y': data.y, 'x': data.x})
del blocks_total

# add to stack
stack.append(intf)
# cleanup
del intf, prm1, prm2
del prms

coord_pair = [' '.join(pair) for pair in pairs]
coord_ref = xr.DataArray(pd.to_datetime(pairs[:,0]), coords={'pair': coord_pair})
coord_rep = xr.DataArray(pd.to_datetime(pairs[:,1]), coords={'pair': coord_pair})
# interpret the topo argument as topography when it is a DataArray named 'topo'; otherwise, use it directly as the topography phase
if isinstance(topo, str) and topo == 'auto':
topo = utils.interp2d_like(self.get_topo(), data, method=method, kwargs={'fill_value': 'extrapolate'})
if (isinstance(topo, xr.DataArray) and topo.name=='topo'):
phase_topo = self.topo_phase(pairs, topo, grid=data, method=method)
else:
phase_topo = topo

return xr.concat(stack, dim='pair').assign_coords(ref=coord_ref, rep=coord_rep, pair=coord_pair).rename('phase')
if phase is not None:
phase_real = utils.interp2d_like(phase, grid=data, method=method, kwargs={'fill_value': 'extrapolate'})
else:
phase_real = 0
#phase_real = len(pairs)*[0]

# calculate phase difference
data1 = data.sel(date=pairs[:,0]).drop_vars('date').rename({'date': 'pair'})
data2 = data.sel(date=pairs[:,1]).drop_vars('date').rename({'date': 'pair'})
out = (data1 * phase_topo * np.exp(-1j * phase_real) * da.conj(data2)).astype(np.complex64)
del phase_topo, phase_real, data1, data2

# # calculate phase difference
# phase_dask = da.stack([(data.sel(date=pair[0]).drop_vars('date') \
# * phase_topo[idx] * np.exp(-1j * phase_real[idx]) \
# * da.conj(data.sel(date=pair[1]).drop_vars('date'))) for idx, pair in enumerate(pairs)], axis=0)
# out = xr.DataArray(phase_dask, coords=phase_topo.coords)
# del phase_topo, phase_real, phase_dask

return out.astype(np.complex64).rename('phase')

def goldstein(self, phase, corr, psize=32, debug=False):
import xarray as xr
Expand Down
Loading

0 comments on commit e7205df

Please sign in to comment.