Skip to content

Commit

Permalink
rewrite partition() to fit windows to functions
Browse files Browse the repository at this point in the history
  • Loading branch information
ntessore committed Aug 1, 2023
1 parent 004b26f commit fb73dad
Show file tree
Hide file tree
Showing 2 changed files with 95 additions and 63 deletions.
120 changes: 95 additions & 25 deletions glass/shells.py
Original file line number Diff line number Diff line change
Expand Up @@ -331,45 +331,115 @@ def restrict(z: ArrayLike1D, f: ArrayLike1D, w: RadialWindow
return zr, fr


def partition(z: ArrayLike1D, f: ArrayLike1D, ws: Sequence[RadialWindow]
) -> Tuple[List[np.ndarray], List[np.ndarray]]:
'''Partition a function by a sequence of windows.
Partitions the given function into a sequence of functions
restricted to each window function.
The function :math:`f(z)` is given by redshifts ``z`` of shape
*(N,)* and function values ``f`` of shape *(..., N)*, with any
number of leading axes allowed.
The window functions are given by the sequence ``ws`` of
def partition(z: ArrayLike,
f: ArrayLike,
ws: Sequence[RadialWindow],
*,
method: str = "lstsq",
) -> ArrayLike:
"""Partition a function by a sequence of windows.
Returns a vector of weights :math:`x_1, x_2, \\ldots` such that the
weighted sum of normalised radial window functions :math:`x_1 \\,
w_1(z) + x_2 \\, w_2(z) + \\ldots` approximates the given function
:math:`f(z)`.
The function :math:`f(z)` is given by redshifts *z* of shape *(N,)*
and function values *f* of shape *(..., N)*, with any number of
leading axes allowed.
The window functions are given by the sequence *ws* of
:class:`RadialWindow` or compatible entries.
The partitioned functions have redshifts that are the union of the
redshifts of the original function and each window over the support
of said window. Intermediate function values are found by linear
interpolation
Parameters
----------
z, f : array_like
The function to be partitioned.
The function to be partitioned. If *f* is multi-dimensional,
its last axis must agree with *z*.
ws : sequence of :class:`RadialWindow`
Ordered sequence of window functions for the partition.
method : {"lstsq", "restrict"}
Method for the partition. See notes for description.
Returns
-------
zp, fp : list of array
The partitioned functions, ordered as the given windows.
x : array_like
Weights of the partition. If *f* is multi-dimensional, the
leading axes of *x* match those of *f*.
Notes
-----
Formally, if :math:`w_i` are the normalised window functions,
:math:`f` is the target function, and :math:`z_i` is a redshift grid
with intervals :math:`\\Delta z_i`, the partition problem seeks an
approximate solution of
.. math::
\\begin{pmatrix}
w_1(z_1) \\Delta z_1 & w_2(z_1) \\, \\Delta z_1 & \\cdots \\\\
w_1(z_2) \\Delta z_2 & w_2(z_2) \\, \\Delta z_2 & \\cdots \\\\
\\vdots & \\vdots & \\ddots
\\end{pmatrix} \\, \\begin{pmatrix}
x_1 \\\\ x_2 \\\\ \\vdots
\\end{pmatrix} = \\begin{pmatrix}
f(z_1) \\, \\Delta z_1 \\\\ f(z_2) \\, \\Delta z_2 \\\\ \\vdots
\\end{pmatrix} \\;.
The redshift grid is the union of the given array *z* and the
redshift arrays of all window functions. Intermediate function
values are found by linear interpolation.
If ``method="lstsq"``, obtain a partition from a least-squares
solution. This will more closely match the shape of the input
function, but the normalisation might differ.
If ``method="restrict"``, obtain a partition by integrating the
restriction (using :func:`restrict`) of the function :math:`f` to
each window. This will more closely match the normalisation of the
input function, but the shape might differ.
"""
try:
partition_method = globals()[f"partition_{method}"]
except KeyError:
raise ValueError(f"invalid method: {method}") from None
return partition_method(z, f, ws)


def partition_lstsq(z: ArrayLike, f: ArrayLike, ws: Sequence[RadialWindow]
) -> ArrayLike:
"""Least-squares partition."""

# compute the union of all given redshift grids
zp = z
for w in ws:
zp = np.union1d(zp, w.za)

'''
# compute grid spacing
dz = np.gradient(zp)

# create the window function matrix
a = [np.interp(zp, za, wa, left=0., right=0.) for za, wa, _ in ws]
a = a/np.trapz(a, zp, axis=-1)[..., None]
a = a*dz

# create the target vector of distribution values
b = ndinterp(zp, z, f, left=0., right=0.)
b = b*dz

# return least-squares fit
return np.linalg.lstsq(a.T, b.T, rcond=None)[0].T


def partition_restrict(z: ArrayLike, f: ArrayLike, ws: Sequence[RadialWindow]
) -> ArrayLike:
"""Partition by restriction and integration."""

zp, fp = [], []
ngal = []
for w in ws:
zr, fr = restrict(z, f, w)
zp.append(zr)
fp.append(fr)
return zp, fp
ngal.append(np.trapz(fr, zr, axis=-1))
return np.transpose(ngal)


def redshift_grid(zmin, zmax, *, dz=None, num=None):
Expand Down
38 changes: 0 additions & 38 deletions glass/test/test_shells.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,4 @@
import numpy as np
import numpy.testing as npt


def test_tophat_windows():
Expand Down Expand Up @@ -47,40 +46,3 @@ def test_restrict():
i = np.searchsorted(zr, zi)
assert zr[i] == zi
assert fr[i] == fi*np.interp(zi, w.za, w.wa)


def test_partition():
from glass.shells import partition, RadialWindow

# Gaussian test function
z = np.linspace(0., 5., 1000)
f = np.exp(-((z - 2.)/0.5)**2/2)

# overlapping triangular weight functions
ws = [RadialWindow(za=[0., 1., 2.], wa=[0., 1., 0.], zeff=None),
RadialWindow(za=[1., 2., 3.], wa=[0., 1., 0.], zeff=None),
RadialWindow(za=[2., 3., 4.], wa=[0., 1., 0.], zeff=None),
RadialWindow(za=[3., 4., 5.], wa=[0., 1., 0.], zeff=None)]

zp, fp = partition(z, f, ws)

assert len(zp) == len(fp) == len(ws)

for zr, w in zip(zp, ws):
assert np.all((zr >= w.za[0]) & (zr <= w.za[-1]))

for zr, fr, w in zip(zp, fp, ws):
f_ = np.interp(zr, z, f, left=0., right=0.)
w_ = np.interp(zr, w.za, w.wa, left=0., right=0.)
npt.assert_allclose(fr, f_*w_)

f_ = sum(np.interp(z, zr, fr, left=0., right=0.)
for zr, fr in zip(zp, fp))

# first and last points have zero total weight
assert f_[0] == f_[-1] == 0.

# find first and last index where total weight becomes unity
i, j = np.searchsorted(z, [ws[0].za[1], ws[-1].za[1]])

npt.assert_allclose(f_[i:j], f[i:j], atol=1e-15)

0 comments on commit fb73dad

Please sign in to comment.