diff --git a/DESCRIPTION b/DESCRIPTION index 6af1407..00d3587 100644 --- a/DESCRIPTION +++ b/DESCRIPTION @@ -1,8 +1,8 @@ Encoding: UTF-8 Type: Package Package: mr.mash.alpha -Version: 0.3.29 -Date: 2024-06-03 +Version: 0.3.34 +Date: 2024-09-15 Title: Multiple Regression with Multivariate Adaptive Shrinkage Description: Provides an implementation of methods for multivariate multiple regression with adaptive shrinkage priors. diff --git a/NAMESPACE b/NAMESPACE index b80beaf..182afdd 100644 --- a/NAMESPACE +++ b/NAMESPACE @@ -36,6 +36,7 @@ importFrom(mashr,mash_set_data) importFrom(matrixStats,colMeans2) importFrom(matrixStats,colSds) importFrom(matrixStats,colVars) +importFrom(methods,as) importFrom(mvtnorm,dmvnorm) importFrom(mvtnorm,rmvnorm) importFrom(parallel,makeCluster) diff --git a/R/mr_mash.R b/R/mr_mash.R index 7c503e6..73d6777 100644 --- a/R/mr_mash.R +++ b/R/mr_mash.R @@ -31,11 +31,17 @@ #' #' @param update_w0_method Method to update prior weights. Only EM is #' currently supported. +#' +#' @param w0_penalty K-vector of penalty terms (>=1) for each +#' mixture component. Default is all components are unpenalized. #' #' @param w0_threshold Drop mixture components with weight less than this value. #' Components are dropped at each iteration after 15 initial iterations. #' This is done to prevent from dropping some poetentially important #' components prematurely. +#' +#' @param update_w0_max_iter Maximum number of iterations for the update +#' of w0. #' #' @param update_V if \code{TRUE}, residual covariance is updated. #' @@ -163,8 +169,8 @@ #' mr.mash <- function(X, Y, S0, w0=rep(1/(length(S0)), length(S0)), V=NULL, mu1_init=matrix(0, nrow=ncol(X), ncol=ncol(Y)), tol=1e-4, convergence_criterion=c("mu1", "ELBO"), - max_iter=5000, update_w0=TRUE, update_w0_method="EM", - w0_threshold=0, compute_ELBO=TRUE, standardize=TRUE, verbose=TRUE, + max_iter=5000, update_w0=TRUE, update_w0_method="EM", w0_penalty=rep(1, length(S0)), + update_w0_max_iter=Inf, w0_threshold=0, compute_ELBO=TRUE, standardize=TRUE, verbose=TRUE, update_V=FALSE, update_V_method=c("full", "diagonal"), version=c("Rcpp", "R"), e=1e-8, ca_update_order=c("consecutive", "decreasing_logBF", "increasing_logBF", "random"), nthreads=as.integer(NA)) { @@ -429,14 +435,15 @@ mr.mash <- function(X, Y, S0, w0=rep(1/(length(S0)), length(S0)), V=NULL, } ##Update w0 if requested - if(update_w0){ - w0 <- update_weights_em(w1_t) + if(update_w0 && t <= update_w0_max_iter){ + w0 <- update_weights_em(w1_t, w0_penalty) #Drop components with mixture weight <= w0_threshold if(t>15 && any(w0 < w0_threshold)){ to_keep <- which(w0 >= w0_threshold) w0 <- w0[to_keep] w0 <- w0/sum(w0) + w0_penalty <- w0_penalty[to_keep] S0 <- S0[to_keep] if(length(to_keep) > 1){ comps <- filter_precomputed_quants(comps, to_keep, standardize, version) diff --git a/R/mr_mash_rss.R b/R/mr_mash_rss.R index efc240a..eaefbf6 100644 --- a/R/mr_mash_rss.R +++ b/R/mr_mash_rss.R @@ -43,10 +43,16 @@ #' #' @param update_w0_method Method to update prior weights. Only EM is #' currently supported. +#' +#' @param update_w0_max_iter Maximum number of iterations for the update +#' of w0. +#' +#' @param w0_penalty K-vector of penalty terms (>=1) for each +#' mixture component. Default is all components are unpenalized. #' #' @param w0_threshold Drop mixture components with weight less than this value. #' Components are dropped at each iteration after 15 initial iterations. -#' This is done to prevent from dropping some poetentially important +#' This is done to prevent from dropping some potentially important #' components prematurely. #' #' @param update_V if \code{TRUE}, residual covariance is updated. @@ -180,8 +186,8 @@ #' mr.mash.rss <- function(Bhat, Shat, Z, R, covY, n, S0, w0=rep(1/(length(S0)), length(S0)), V, mu1_init=NULL, tol=1e-4, convergence_criterion=c("mu1", "ELBO"), - max_iter=5000, update_w0=TRUE, update_w0_method="EM", - w0_threshold=0, compute_ELBO=TRUE, standardize=FALSE, verbose=TRUE, + max_iter=5000, update_w0=TRUE, update_w0_method="EM", w0_penalty=rep(1, length(S0)), + update_w0_max_iter=Inf, w0_threshold=0, compute_ELBO=TRUE, standardize=FALSE, verbose=TRUE, update_V=FALSE, update_V_method=c("full", "diagonal"), version=c("Rcpp", "R"), e=1e-8, ca_update_order=c("consecutive", "decreasing_logBF", "increasing_logBF", "random"), X_colmeans=NULL, Y_colmeans=NULL, check_R=TRUE, R_tol=1e-08, @@ -463,14 +469,15 @@ mr.mash.rss <- function(Bhat, Shat, Z, R, covY, n, S0, w0=rep(1/(length(S0)), le } ##Update w0 if requested - if(update_w0){ - w0 <- update_weights_em(w1_t) + if(update_w0 && t <= update_w0_max_iter){ + w0 <- update_weights_em(w1_t, w0_penalty) #Drop components with mixture weight <= w0_threshold if(t>15 && any(w0 < w0_threshold)){ to_keep <- which(w0 >= w0_threshold) w0 <- w0[to_keep] w0 <- w0/sum(w0) + w0_penalty <- w0_penalty[to_keep] S0 <- S0[to_keep] if(length(to_keep) > 1){ comps <- filter_precomputed_quants(comps, to_keep, standardize, version) @@ -483,7 +490,7 @@ mr.mash.rss <- function(Bhat, Shat, Z, R, covY, n, S0, w0=rep(1/(length(S0)), le } else { #some other component is the only one left stop("Only one component (different from the null) left. Consider lowering w0_threshold.") } - } + } } } diff --git a/R/mr_mash_updates.R b/R/mr_mash_updates.R index e902d1e..524aa93 100644 --- a/R/mr_mash_updates.R +++ b/R/mr_mash_updates.R @@ -202,13 +202,13 @@ update_V_fun <- function(Y, mu, var_part_ERSS, Y_cov){ ###Update mixture weights -update_weights_em <- function(x){ +update_weights_em <- function(x, lambda){ w <- colSums(x) + w <- w + lambda - 1 w <- w/sum(w) return(w) } - ###Impute/update missing Y impute_missing_Y_R <- function(Y, mu, Vinv, miss, non_miss){ n <- nrow(Y) diff --git a/man/mr.mash.Rd b/man/mr.mash.Rd index bcd2a65..22707fa 100644 --- a/man/mr.mash.Rd +++ b/man/mr.mash.Rd @@ -11,11 +11,13 @@ mr.mash( w0 = rep(1/(length(S0)), length(S0)), V = NULL, mu1_init = matrix(0, nrow = ncol(X), ncol = ncol(Y)), - tol = 0.0001, + tol = 1e-04, convergence_criterion = c("mu1", "ELBO"), max_iter = 5000, update_w0 = TRUE, update_w0_method = "EM", + w0_penalty = rep(1, length(S0)), + update_w0_max_iter = Inf, w0_threshold = 0, compute_ELBO = TRUE, standardize = TRUE, @@ -58,6 +60,12 @@ algorithm.} \item{update_w0_method}{Method to update prior weights. Only EM is currently supported.} +\item{w0_penalty}{K-vector of penalty terms (>=1) for each +mixture component. Default is all components are unpenalized.} + +\item{update_w0_max_iter}{Maximum number of iterations for the update +of w0.} + \item{w0_threshold}{Drop mixture components with weight less than this value. Components are dropped at each iteration after 15 initial iterations. This is done to prevent from dropping some poetentially important diff --git a/man/mr.mash.rss.Rd b/man/mr.mash.rss.Rd index 3691336..c32e530 100644 --- a/man/mr.mash.rss.Rd +++ b/man/mr.mash.rss.Rd @@ -16,11 +16,13 @@ mr.mash.rss( w0 = rep(1/(length(S0)), length(S0)), V, mu1_init = NULL, - tol = 0.0001, + tol = 1e-04, convergence_criterion = c("mu1", "ELBO"), max_iter = 5000, update_w0 = TRUE, update_w0_method = "EM", + w0_penalty = rep(1, length(S0)), + update_w0_max_iter = Inf, w0_threshold = 0, compute_ELBO = TRUE, standardize = FALSE, @@ -78,9 +80,15 @@ algorithm.} \item{update_w0_method}{Method to update prior weights. Only EM is currently supported.} +\item{w0_penalty}{K-vector of penalty terms (>=1) for each +mixture component. Default is all components are unpenalized.} + +\item{update_w0_max_iter}{Maximum number of iterations for the update +of w0.} + \item{w0_threshold}{Drop mixture components with weight less than this value. Components are dropped at each iteration after 15 initial iterations. -This is done to prevent from dropping some poetentially important +This is done to prevent from dropping some potentially important components prematurely.} \item{compute_ELBO}{If \code{TRUE}, ELBO is computed.} diff --git a/src/mr_mash_updates.cpp b/src/mr_mash_updates.cpp index f27b143..dbd9543 100644 --- a/src/mr_mash_updates.cpp +++ b/src/mr_mash_updates.cpp @@ -238,14 +238,14 @@ void impute_missing_Y (mat& Y, const mat& mu, const mat& Vinv, Y_cov += Y_cov_i; // Compute mean - vec Y_i = Y.row(i); - vec mu_i_miss = mu.row(i); + rowvec Y_i = Y.row(i); + rowvec mu_i_miss = mu.row(i); mu_i_miss = mu_i_miss.elem(miss_i_idx); - vec mu_i_non_miss = mu.row(i); - mu_i_non_miss = mu_i_non_miss.elem(non_miss_i_idx); - vec Y_i_non_miss = Y.row(i); - Y_i_non_miss = Y_i_non_miss.elem(non_miss_i_idx); - Y_i.elem(miss_i_idx) = mu_i_miss - Y_cov_mm * Vinv_mo * (Y_i_non_miss - mu_i_non_miss); + rowvec mu_i_non_miss = mu.row(i); + mu_i_non_miss = trans(mu_i_non_miss.elem(non_miss_i_idx)); + rowvec Y_i_non_miss = Y.row(i); + Y_i_non_miss = trans(Y_i_non_miss.elem(non_miss_i_idx)); + Y_i.elem(miss_i_idx) = mu_i_miss - Y_cov_mm * Vinv_mo * trans(Y_i_non_miss - mu_i_non_miss); Y.row(i) = Y_i;