Skip to content

Commit

Permalink
fix: ensure MaxGCP weights are non complex
Browse files Browse the repository at this point in the history
  • Loading branch information
zietzm committed Aug 29, 2024
1 parent 5487bad commit ea910d6
Showing 1 changed file with 13 additions and 3 deletions.
16 changes: 13 additions & 3 deletions src/maxgcp/estimators.py
Original file line number Diff line number Diff line change
Expand Up @@ -61,9 +61,19 @@ def fit_heritability(cov_G: Matrix, cov_P: Matrix) -> Matrix:
check_matrix_inputs(cov_G, cov_P)

cov_G_sqrt: NDArray = scipy.linalg.sqrtm(cov_G) # type: ignore
lhs = cov_G_sqrt @ np.linalg.pinv(cov_P) @ cov_G_sqrt
_, evecs = np.linalg.eig(lhs)
weights = np.linalg.pinv(cov_G_sqrt) @ evecs
if np.iscomplexobj(cov_G_sqrt):
raise ValueError("Input covariance matrix must be real")
lhs = cov_G_sqrt @ np.linalg.inv(cov_P) @ cov_G_sqrt
if np.iscomplexobj(lhs):
raise ValueError("LHS matrix must be real")
if not np.allclose(lhs, lhs.T, atol=1e-5):
raise ValueError("LHS matrix must be symmetric")
_, evecs = np.linalg.eigh(lhs)
if np.iscomplexobj(evecs):
raise ValueError("Eigenvectors must be real")
weights = np.linalg.inv(cov_G_sqrt) @ evecs
if np.iscomplexobj(weights):
raise ValueError("Eigenvectors must be real")
weights = np.asarray(weights)

# Normalize weights so that projections have unit variance
Expand Down

0 comments on commit ea910d6

Please sign in to comment.