Skip to content

Commit

Permalink
Merge pull request #476 from simonsobs/gapfill_wrapper
Browse files Browse the repository at this point in the history
Gapfill wrapper
  • Loading branch information
msilvafe authored Oct 11, 2023
2 parents c33a979 + 61ffc07 commit 5796df2
Show file tree
Hide file tree
Showing 4 changed files with 168 additions and 1 deletion.
30 changes: 30 additions & 0 deletions sotodlib/preprocess/processes.py
Original file line number Diff line number Diff line change
Expand Up @@ -244,6 +244,35 @@ class Demodulate(_Preprocess):
def process(self, aman, proc_aman):
hwp.demod_tod(aman, **self.process_cfgs)


class GlitchFill(_Preprocess):
"""Fill glitches. All process configs go to `fill_glitches`.
.. autofunction:: sotodlib.tod_ops.gapfill.fill_glitches
"""
name = "glitchfill"

def process(self, aman):
pcfgs = np.fromiter(self.process_cfgs.keys(), dtype='U16')
if 'glitch_flags' in pcfgs:
flags = aman.flags[self.process_cfgs["glitch_flags"]]
pcfgs = np.delete(pcfgs, np.where(pcfgs == 'glitch_flags'))
else:
flags = None

if 'signal' in pcfgs:
signal = aman[self.process_cfgs["signal"]]
pcfgs = np.delete(pcfgs, np.where(pcfgs == 'signal'))
else:
signal = None

args = {}
for pcfg in pcfgs:
args[pcfg] = self.process_cfgs[pcfg]

tod_ops.gapfill.fill_glitches(aman, signal=signal, glitch_flags=flags, **args)


_Preprocess.register(Trends.name, Trends)
_Preprocess.register(FFTTrim.name, FFTTrim)
_Preprocess.register(Detrend.name, Detrend)
Expand All @@ -255,3 +284,4 @@ def process(self, aman, proc_aman):
_Preprocess.register(SubtractHWPSS.name, SubtractHWPSS)
_Preprocess.register(Apodize.name, Apodize)
_Preprocess.register(Demodulate.name, Demodulate)
_Preprocess.register(GlitchFill.name, GlitchFill)
80 changes: 80 additions & 0 deletions sotodlib/tod_ops/gapfill.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,9 @@
import numpy as np
import so3g
from . import pca
import logging

logger = logging.getLogger(__name__)

class Extract:
"""Container for storage of sparse sub-segments of a vector. This
Expand Down Expand Up @@ -380,3 +383,80 @@ def get_contaminated_ranges(good_flags, bad_flags):
rs.add_interval(int(i0), int(i1))
return contam


def fill_glitches(aman, nbuf=10, use_pca=False, modes=3, signal=None,
glitch_flags=None, wrap=True):
"""
This function fills pre-computed glitches provided by the caller in
time-ordered data using either a polynomial (default) or PCA-based
approach. Wraps the other functions in the ``tod_ops.gapfill`` module.
Args
-----
aman : AxisManager
AxisManager to fill glitches in
nbuf : int
Number of buffer samples to use in polynomial gap filling.
use_pca : bool
Whether or not to fill glitches using pca model. Default is False
modes : int
Number of modes in the pca to use if pca=True. Default is 3.
signal : ndarray or None
Array of data to fill glitches in. If None then uses ``aman.signal``.
Default is None.
glitch_flags : RangesMatrix or None
RangesMatrix containing flags to use for gap filling. If None then
uses ``aman.flags.glitches``.
wrap : bool or str
If True wraps new field called ``gap_filled``, if False returns the
gap filled array, if a string wraps new field with provided name.
Returns
-------
signal : ndarray
Returns ndarray with gaps filled from input signal.
"""
# Process Args
if signal is None:
sig = np.copy(aman.signal)
else:
sig = np.copy(signal)

if glitch_flags is None:
glitch_flags = aman.flags.glitches

# Polyfill
gaps = get_gap_fill(aman, nbuf=nbuf, flags=glitch_flags,
signal=np.float32(sig))
sig = gaps.swap(aman, signal=sig)

#PCA Fill
if use_pca:
if modes > aman.dets.count:
logger.warning(f'modes = {modes} > number of detectors = ' +
f'{aman.dets.count}, setting modes = number of ' +
'detectors')
modes = aman.dets.count
# fill with poly fill before PCA
gaps = get_gap_fill(aman, nbuf=nbuf, flags=glitch_flags,
signal=np.float32(sig))
sig = gaps.swap(aman, signal=sig)
# PCA fill
mod = pca.get_pca_model(tod=aman, n_modes=modes,
signal=sig)
gfill = get_gap_model(tod=aman, model=mod, flags=glitch_flags)
sig = gfill.swap(aman, signal=sig)

# Wrap and Return
if isinstance(wrap, str):
if wrap in aman._assignments:
aman.move(wrap, None)
aman.wrap(wrap, sig, [(0, 'dets'), (1, 'samps')])
return sig
elif wrap:
if 'gap_filled' in aman._assignments:
aman.move('gap_filled', None)
aman.wrap('gap_filled', sig, [(0, 'dets'), (1, 'samps')])
return sig
else:
return sig
3 changes: 3 additions & 0 deletions sotodlib/tod_ops/pca.py
Original file line number Diff line number Diff line change
Expand Up @@ -109,8 +109,11 @@ def get_pca(tod=None, cov=None, signal=None, wrap=None):
output = core.AxisManager(dets, mode_axis)
output.wrap('cov', cov, [(0, dets.name), (1, dets.name)])

# Note eig will sometimes return complex eigenvalues.
E, R = np.linalg.eig(cov) # eigh nans sometimes...
E[np.isnan(E)] = 0.
E, R = E.real, R.real

idx = np.argsort(-E)
output.wrap('E', E[idx], [(0, mode_axis.name)])
output.wrap('R', R[:, idx], [(0, dets.name), (1, mode_axis.name)])
Expand Down
56 changes: 55 additions & 1 deletion tests/test_tod_ops.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,7 @@

from numpy.testing import assert_array_equal, assert_allclose

from sotodlib import core, tod_ops
from sotodlib import core, tod_ops, sim_flags
import so3g

from ._helpers import mpi_multi
Expand Down Expand Up @@ -40,6 +40,38 @@ def get_tod(sig_type='trendy'):
raise RuntimeError(f'sig_type={sig_type}?')
return tod


def get_glitchy_tod(ts, noise_amp=0, ndets=2, npoly=3, poly_coeffs=None):
"""Returns axis manager to test fill_glitches"""
fake_signal = np.zeros((ndets, len(ts)))
input_sig = np.zeros((ndets, len(ts)))
if poly_coeffs is None:
poly_coeffs = np.random.uniform(0.5, 1.6, npoly)*1e-1
poly_sig = np.polyval(poly_coeffs, ts-np.mean(ts))
for nd in range(ndets):
input_sig[nd] = poly_sig
fake_signal[nd] = poly_sig
noise = np.random.normal(0, noise_amp, size=len(ts))
fake_signal[nd] += noise

dets = ['det%i' % i for i in range(ndets)]

tod_fake = core.AxisManager(core.LabelAxis('dets', vals=dets),
core.OffsetAxis('samps', count=len(ts)))
tod_fake.wrap('timestamps', ts, axis_map=[(0, 'samps')])
tod_fake.wrap('signal', np.atleast_2d(fake_signal),
axis_map=[(0, 'dets'), (1, 'samps')])
tod_fake.wrap('inputsignal', np.atleast_2d(input_sig),
axis_map=[(0, 'dets'), (1, 'samps')])
flgs = core.AxisManager()
tod_fake.wrap('flags', flgs)
params = {'n_glitches': 10, 'sig_n_glitch': 10, 'h_glitc h': 10,
'sig_h_glitch': 2}
sim_flags.add_random_glitches(tod_fake, params=params, signal='signal',
flag='glitches', overwrite=False)
return tod_fake


class FactorsTest(unittest.TestCase):
def test_inf(self):
f = tod_ops.fft_ops.find_inferior_integer
Expand Down Expand Up @@ -121,6 +153,28 @@ def test_basic(self):
assert_allclose(tod.signal[1][gap_mask], sentinel)
# ... check "extraction" has model values.
assert_allclose(ex[1].data, sig[gap_mask], atol=atol)

def test_fillglitches(self):
"""Tests fill glitches wrapper function"""
ts = np.arange(0, 1*60, 1/200)
aman = get_glitchy_tod(ts, ndets=100)
# test poly fill
up, mg = False, False
glitch_filled = tod_ops.gapfill.fill_glitches(aman, use_pca=up,
wrap=mg)
self.assertTrue(np.max(np.abs(glitch_filled-aman.inputsignal)) < 1e-3)

# test pca fill
up, mg = True, False
glitch_filled = tod_ops.gapfill.fill_glitches(aman, use_pca=up,
wrap=mg)
print(np.max(np.abs(glitch_filled-aman.inputsignal)))

# test wrap new field
up, mg = False, True
glitch_filled = tod_ops.gapfill.fill_glitches(aman, use_pca=up,
wrap=mg)
self.assertTrue('gap_filled' in aman._assignments)

class FilterTest(unittest.TestCase):
def test_basic(self):
Expand Down

0 comments on commit 5796df2

Please sign in to comment.