diff --git a/tests/test_mask.py b/tests/test_mask.py index b71fc64e..16e149e3 100644 --- a/tests/test_mask.py +++ b/tests/test_mask.py @@ -14,11 +14,11 @@ def test_convgnp_mask(nps): conv_receptive_field=0.5, conv_layers=1, conv_channels=1, - # Dividing by the density channel makes the forward very sensitive to the - # numerics. - divide_by_density=False, + # A large margin and `float64`s help with numerical stability. + margin=2, + dtype=nps.dtype64, ) - xc, yc, xt, yt = generate_data(nps) + xc, yc, xt, yt = generate_data(nps, dtype=nps.dtype64) # Predict without the final three points. pred = model(xc[:, :, :-3], yc[:, :, :-3], xt) diff --git a/tests/util.py b/tests/util.py index 276bcca1..accee0aa 100644 --- a/tests/util.py +++ b/tests/util.py @@ -51,7 +51,7 @@ def approx( nps_tf.dtype64 = tf.float64 -@pytest.fixture(params=[nps_torch, nps_tf], scope="module") +@pytest.fixture(params=[nps_tf, nps_torch], scope="module") def nps(request): return request.param @@ -64,14 +64,17 @@ def generate_data( n_context=5, n_target=7, binary=False, + dtype=None, ): - xc = B.randn(nps.dtype, batch_size, dim_x, n_context) - yc = B.randn(nps.dtype, batch_size, dim_y, n_context) - xt = B.randn(nps.dtype, batch_size, dim_x, n_target) - yt = B.randn(nps.dtype, batch_size, dim_y, n_target) + if dtype is None: + dtype = nps.dtype + xc = B.randn(dtype, batch_size, dim_x, n_context) + yc = B.randn(dtype, batch_size, dim_y, n_context) + xt = B.randn(dtype, batch_size, dim_x, n_target) + yt = B.randn(dtype, batch_size, dim_y, n_target) if binary: - yc = B.cast(nps.dtype, yc >= 0) - yt = B.cast(nps.dtype, yt >= 0) + yc = B.cast(dtype, yc >= 0) + yt = B.cast(dtype, yt >= 0) return xc, yc, xt, yt