Skip to content

Commit

Permalink
Use sgkit.distarray for gwas_linear_regression
Browse files Browse the repository at this point in the history
  • Loading branch information
tomwhite committed Sep 30, 2024
1 parent d522c0f commit 51a8ffa
Show file tree
Hide file tree
Showing 5 changed files with 17 additions and 11 deletions.
2 changes: 1 addition & 1 deletion .github/workflows/cubed.yml
Original file line number Diff line number Diff line change
Expand Up @@ -30,4 +30,4 @@ jobs:
- name: Test with pytest
run: |
pytest -v sgkit/tests/test_{aggregation,hwe}.py -k 'test_count_call_alleles or test_hwep or test_sample_stats or (test_count_variant_alleles and not test_count_variant_alleles__chunked[call_genotype]) or (test_variant_stats and not test_variant_stats__chunks[chunks2-False])' --use-cubed
pytest -v sgkit/tests/test_{aggregation,association,hwe}.py -k 'test_count_call_alleles or test_gwas_linear_regression or test_hwep or test_sample_stats or (test_count_variant_alleles and not test_count_variant_alleles__chunked[call_genotype]) or (test_variant_stats and not test_variant_stats__chunks[chunks2-False])' --use-cubed
3 changes: 3 additions & 0 deletions sgkit/distarray.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,3 +14,6 @@ def astype(x, dtype, /, *, copy=True): # pragma: no cover
if not copy and dtype == x.dtype:
return x
return x.astype(dtype=dtype, copy=copy)

# dask doesn't have concat required by the array API
concat = concatenate # noqa: F405
17 changes: 10 additions & 7 deletions sgkit/stats/association.py
Original file line number Diff line number Diff line change
@@ -1,11 +1,13 @@
from dataclasses import dataclass
from typing import Hashable, Optional, Sequence, Union

import dask.array as da
import numpy as np
from dask.array import Array, stats
from scipy import stats
from xarray import Dataset, concat

import sgkit.distarray as da
from sgkit.distarray import Array

from .. import variables
from ..typing import ArrayLike
from ..utils import conditional_merge_datasets, create_dataset
Expand Down Expand Up @@ -78,18 +80,18 @@ def linear_regression(
# from projection require no extra terms in variance
# estimate for loop covariates (columns of G), which is
# only true when an intercept is present.
XLPS = (XLP**2).sum(axis=0, keepdims=True).T
XLPS = da.sum(XLP**2, axis=0, keepdims=True).T
assert XLPS.shape == (n_loop_covar, 1)
B = (XLP.T @ YP) / XLPS
assert B.shape == (n_loop_covar, n_outcome)

# Compute residuals for each loop covariate and outcome separately
YR = YP[:, np.newaxis, :] - XLP[..., np.newaxis] * B[np.newaxis, ...]
assert YR.shape == (n_obs, n_loop_covar, n_outcome)
RSS = (YR**2).sum(axis=0)
RSS = da.sum(YR**2, axis=0)
assert RSS.shape == (n_loop_covar, n_outcome)
# Get t-statistics for coefficient estimates
T = B / np.sqrt(RSS / dof / XLPS)
T = B / da.sqrt(RSS / dof / XLPS)
assert T.shape == (n_loop_covar, n_outcome)

# Match to p-values
Expand All @@ -102,7 +104,8 @@ def linear_regression(
dtype="float64",
)
assert P.shape == (n_loop_covar, n_outcome)
P = np.asarray(P, like=T)
if hasattr(T, "__array_function__"):
P = np.asarray(P, like=T)
return LinearRegressionResult(beta=B, t_value=T, p_value=P)


Expand Down Expand Up @@ -216,7 +219,7 @@ def gwas_linear_regression(
else:
X = da.asarray(concat_2d(ds[list(covariates)], dims=("samples", "covariates")))
if add_intercept:
X = da.concatenate([da.ones((X.shape[0], 1), dtype=X.dtype), X], axis=1)
X = da.concat([da.ones((X.shape[0], 1), dtype=X.dtype), X], axis=1)
# Note: dask qr decomp (used by lstsq) requires no chunking in one
# dimension, and because dim 0 will be far greater than the number
# of covariates for the large majority of use cases, chunking
Expand Down
2 changes: 1 addition & 1 deletion sgkit/stats/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -104,7 +104,7 @@ def assert_array_shape(x: ArrayLike, *args: int) -> None:


def map_blocks_asnumpy(x: Array) -> Array:
if da.utils.is_cupy_type(x._meta): # pragma: no cover
if hasattr(x, "_meta") and da.utils.is_cupy_type(x._meta): # pragma: no cover
import cupy as cp # type: ignore[import]

x = x.map_blocks(cp.asnumpy)
Expand Down
4 changes: 2 additions & 2 deletions sgkit/tests/test_association.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,6 @@
from pathlib import Path
from typing import Any, Dict, List, Optional, Sequence, Tuple

import dask.array as da
import numpy as np
import pandas as pd
import pytest
Expand All @@ -11,6 +10,7 @@
from pandas import DataFrame
from xarray import Dataset

import sgkit.distarray as da
from sgkit.stats.association import (
gwas_linear_regression,
linear_regression,
Expand Down Expand Up @@ -263,7 +263,7 @@ def test_gwas_linear_regression__scalar_vars(ds: xr.Dataset) -> None:
res_list = gwas_linear_regression(
ds, dosage="dosage", covariates=["covar_0"], traits=["trait_0"]
)
xr.testing.assert_equal(res_scalar, res_list)
xr.testing.assert_allclose(res_scalar, res_list)


def test_gwas_linear_regression__raise_on_no_intercept_and_empty_covariates():
Expand Down

0 comments on commit 51a8ffa

Please sign in to comment.