From ea910d6bb95cb7a13e2d620b4ffcf928e307ef81 Mon Sep 17 00:00:00 2001 From: zietzm Date: Thu, 29 Aug 2024 15:41:04 -0700 Subject: [PATCH] fix: ensure MaxGCP weights are non complex --- src/maxgcp/estimators.py | 16 +++++++++++++--- 1 file changed, 13 insertions(+), 3 deletions(-) diff --git a/src/maxgcp/estimators.py b/src/maxgcp/estimators.py index 23640df..e62ae85 100644 --- a/src/maxgcp/estimators.py +++ b/src/maxgcp/estimators.py @@ -61,9 +61,19 @@ def fit_heritability(cov_G: Matrix, cov_P: Matrix) -> Matrix: check_matrix_inputs(cov_G, cov_P) cov_G_sqrt: NDArray = scipy.linalg.sqrtm(cov_G) # type: ignore - lhs = cov_G_sqrt @ np.linalg.pinv(cov_P) @ cov_G_sqrt - _, evecs = np.linalg.eig(lhs) - weights = np.linalg.pinv(cov_G_sqrt) @ evecs + if np.iscomplexobj(cov_G_sqrt): + raise ValueError("Input covariance matrix must be real") + lhs = cov_G_sqrt @ np.linalg.inv(cov_P) @ cov_G_sqrt + if np.iscomplexobj(lhs): + raise ValueError("LHS matrix must be real") + if not np.allclose(lhs, lhs.T, atol=1e-5): + raise ValueError("LHS matrix must be symmetric") + _, evecs = np.linalg.eigh(lhs) + if np.iscomplexobj(evecs): + raise ValueError("Eigenvectors must be real") + weights = np.linalg.inv(cov_G_sqrt) @ evecs + if np.iscomplexobj(weights): + raise ValueError("Eigenvectors must be real") weights = np.asarray(weights) # Normalize weights so that projections have unit variance