Skip to content

Commit

Permalink
tests: add tests for batched LS
Browse files Browse the repository at this point in the history
  • Loading branch information
lgarrison committed Mar 8, 2024
1 parent 78f0aa6 commit d4f2cb5
Showing 1 changed file with 136 additions and 8 deletions.
144 changes: 136 additions & 8 deletions tests/test_ls.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,12 +2,13 @@
import pytest


def gen_data(N=100, seed=5043, dtype=np.float64):
def gen_data(N=100, nobj=None, seed=5043, dtype=np.float64):
rng = np.random.default_rng(seed)

t = np.sort(rng.random(N, dtype=dtype)) * 123
y = np.sin(20 * t)
dy = rng.random(N, dtype=dtype) * 0.1 + 0.01
freqs = rng.random((nobj, 1) if nobj else 1, dtype=dtype) * 10 + 1
y = np.sin(freqs * t)
dy = rng.random(y.shape, dtype=dtype) * 0.1 + 0.01

t.setflags(write=False)
y.setflags(write=False)
Expand All @@ -26,6 +27,16 @@ def bench_data():
return gen_data(N=3_000)


@pytest.fixture(scope='module')
def batched_data():
return gen_data(nobj=100)


@pytest.fixture(scope='module')
def batched_bench_data():
return gen_data(N=3_000, nobj=100)


def astropy(
t,
y,
Expand Down Expand Up @@ -69,7 +80,30 @@ def test_lombscargle(data, Nf):

dtype = t.dtype

# breakpoint()
assert np.allclose(nifty_res, brute_res, rtol=1e-9 if dtype == np.float64 else 1e-5)


def test_batched(batched_data, Nf=1000):
"""Check various batching modes"""
import nifty_ls

t = batched_data['t']
y_batch = batched_data['y']
dy_batch = batched_data['dy']

fmin = 0.1
fmax = 10.0

nifty_res = nifty_ls.lombscargle(t, y_batch, dy_batch, fmin=fmin, fmax=fmax, Nf=Nf)

brute_res = np.empty((len(y_batch), Nf), dtype=y_batch.dtype)
for i in range(len(y_batch)):
brute_res[i] = astropy(
t, y_batch[i], dy_batch[i], fmin, fmax, Nf, use_fft=False
)

dtype = t.dtype

assert np.allclose(nifty_res, brute_res, rtol=1e-9 if dtype == np.float64 else 1e-5)


Expand All @@ -92,7 +126,43 @@ def test_astropy_hook(data, Nf=1000):
assert np.allclose(astropy_power, nifty_power)


# TODO: batch test
@pytest.mark.parametrize('Nf', [1_000])
def test_no_cpp_helpers(data, batched_data, Nf):
"""Check that the _no_cpp_helpers flag works as expected for batched and unbatched"""
import nifty_ls

fmin = 0.1
fmax = 10.0

nifty_power = nifty_ls.lombscargle(
**data, fmin=fmin, fmax=fmax, Nf=Nf, backend_kwargs=dict(_no_cpp_helpers=False)
)

nocpp_power = nifty_ls.lombscargle(
**data, fmin=fmin, fmax=fmax, Nf=Nf, backend_kwargs=dict(_no_cpp_helpers=True)
)

assert np.allclose(nifty_power, nocpp_power)

nifty_power_batched = nifty_ls.lombscargle(
**batched_data,
fmin=fmin,
fmax=fmax,
Nf=Nf,
backend_kwargs=dict(_no_cpp_helpers=False),
)

nocpp_power_batched = nifty_ls.lombscargle(
**batched_data,
fmin=fmin,
fmax=fmax,
Nf=Nf,
backend_kwargs=dict(_no_cpp_helpers=True),
)

assert np.allclose(nifty_power_batched, nocpp_power_batched)


# TODO: cuda test
# TODO: center_data, fit_mean, normalization tests

Expand All @@ -106,13 +176,20 @@ class TestPerf:
this integration is.
"""

fmin = 0.1
fmax = 10.0

def test_nifty(self, bench_data, Nf, benchmark):
import nifty_ls

benchmark(nifty_ls.lombscargle, **bench_data, fmin=0.1, fmax=10.0, Nf=Nf)
benchmark(
nifty_ls.lombscargle, **bench_data, fmin=self.fmin, fmax=self.fmax, Nf=Nf
)

def test_astropy(self, bench_data, Nf, benchmark):
benchmark(astropy, **bench_data, fmin=0.1, fmax=10.0, Nf=Nf, use_fft=True)
benchmark(
astropy, **bench_data, fmin=self.fmin, fmax=self.fmax, Nf=Nf, use_fft=True
)

# Usually this benchmark isn't very useful, since one will always use the
# compiled extensions in practice, but if looking at the performance
Expand All @@ -121,4 +198,55 @@ def test_astropy(self, bench_data, Nf, benchmark):
# import nifty_ls

# benchmark(nifty_ls.lombscargle, **bench_data, fmin=0.1, fmax=10.0, Nf=Nf,
# backend_kwargs=dict(no_cpp_helpers=True))
# backend_kwargs=dict(_no_cpp_helpers=True))


@pytest.mark.parametrize('Nf', [1_000])
class TestBatchedPerf:
fmin = 0.1
fmax = 10.0

def test_batched(self, batched_bench_data, Nf, benchmark):
import nifty_ls

benchmark(
nifty_ls.lombscargle,
**batched_bench_data,
fmin=self.fmin,
fmax=self.fmax,
Nf=Nf,
)

def test_unbatched(self, batched_bench_data, Nf, benchmark):
import nifty_ls

t = batched_bench_data['t']
y_batch = batched_bench_data['y']
dy_batch = batched_bench_data['dy']

def _nifty():
for i in range(len(y_batch)):
nifty_ls.lombscargle(
t, y_batch[i], dy_batch[i], fmin=self.fmin, fmax=self.fmax, Nf=Nf
)

benchmark(_nifty)

def test_astropy_unbatched(self, batched_bench_data, Nf, benchmark):
t = batched_bench_data['t']
y_batch = batched_bench_data['y']
dy_batch = batched_bench_data['dy']

def _astropy():
for i in range(len(y_batch)):
astropy(
t,
y_batch[i],
dy_batch[i],
fmin=self.fmin,
fmax=self.fmax,
Nf=Nf,
use_fft=True,
)

benchmark(_astropy)

0 comments on commit d4f2cb5

Please sign in to comment.