Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

A few updates #6

Merged
merged 12 commits into from
Sep 18, 2024
4 changes: 2 additions & 2 deletions DESCRIPTION
Original file line number Diff line number Diff line change
@@ -1,8 +1,8 @@
Encoding: UTF-8
Type: Package
Package: mr.mash.alpha
Version: 0.3.29
Date: 2024-06-03
Version: 0.3.34
Date: 2024-09-15
Title: Multiple Regression with Multivariate Adaptive Shrinkage
Description: Provides an implementation of methods for multivariate
multiple regression with adaptive shrinkage priors.
Expand Down
1 change: 1 addition & 0 deletions NAMESPACE
Original file line number Diff line number Diff line change
Expand Up @@ -36,6 +36,7 @@ importFrom(mashr,mash_set_data)
importFrom(matrixStats,colMeans2)
importFrom(matrixStats,colSds)
importFrom(matrixStats,colVars)
importFrom(methods,as)
importFrom(mvtnorm,dmvnorm)
importFrom(mvtnorm,rmvnorm)
importFrom(parallel,makeCluster)
Expand Down
15 changes: 11 additions & 4 deletions R/mr_mash.R
Original file line number Diff line number Diff line change
Expand Up @@ -31,11 +31,17 @@
#'
#' @param update_w0_method Method to update prior weights. Only EM is
#' currently supported.
#'
#' @param w0_penalty K-vector of penalty terms (>=1) for each
#' mixture component. Default is all components are unpenalized.
#'
#' @param w0_threshold Drop mixture components with weight less than this value.
#' Components are dropped at each iteration after 15 initial iterations.
#' This is done to prevent from dropping some poetentially important
#' components prematurely.
#'
#' @param update_w0_max_iter Maximum number of iterations for the update
#' of w0.
#'
#' @param update_V if \code{TRUE}, residual covariance is updated.
#'
Expand Down Expand Up @@ -163,8 +169,8 @@
#'
mr.mash <- function(X, Y, S0, w0=rep(1/(length(S0)), length(S0)), V=NULL,
mu1_init=matrix(0, nrow=ncol(X), ncol=ncol(Y)), tol=1e-4, convergence_criterion=c("mu1", "ELBO"),
max_iter=5000, update_w0=TRUE, update_w0_method="EM",
w0_threshold=0, compute_ELBO=TRUE, standardize=TRUE, verbose=TRUE,
max_iter=5000, update_w0=TRUE, update_w0_method="EM", w0_penalty=rep(1, length(S0)),
update_w0_max_iter=Inf, w0_threshold=0, compute_ELBO=TRUE, standardize=TRUE, verbose=TRUE,
update_V=FALSE, update_V_method=c("full", "diagonal"), version=c("Rcpp", "R"), e=1e-8,
ca_update_order=c("consecutive", "decreasing_logBF", "increasing_logBF", "random"),
nthreads=as.integer(NA)) {
Expand Down Expand Up @@ -429,14 +435,15 @@ mr.mash <- function(X, Y, S0, w0=rep(1/(length(S0)), length(S0)), V=NULL,
}

##Update w0 if requested
if(update_w0){
w0 <- update_weights_em(w1_t)
if(update_w0 && t <= update_w0_max_iter){
w0 <- update_weights_em(w1_t, w0_penalty)

#Drop components with mixture weight <= w0_threshold
if(t>15 && any(w0 < w0_threshold)){
to_keep <- which(w0 >= w0_threshold)
w0 <- w0[to_keep]
w0 <- w0/sum(w0)
w0_penalty <- w0_penalty[to_keep]
S0 <- S0[to_keep]
if(length(to_keep) > 1){
comps <- filter_precomputed_quants(comps, to_keep, standardize, version)
Expand Down
19 changes: 13 additions & 6 deletions R/mr_mash_rss.R
Original file line number Diff line number Diff line change
Expand Up @@ -43,10 +43,16 @@
#'
#' @param update_w0_method Method to update prior weights. Only EM is
#' currently supported.
#'
#' @param update_w0_max_iter Maximum number of iterations for the update
#' of w0.
#'
#' @param w0_penalty K-vector of penalty terms (>=1) for each
#' mixture component. Default is all components are unpenalized.
#'
#' @param w0_threshold Drop mixture components with weight less than this value.
#' Components are dropped at each iteration after 15 initial iterations.
#' This is done to prevent from dropping some poetentially important
#' This is done to prevent from dropping some potentially important
#' components prematurely.
#'
#' @param update_V if \code{TRUE}, residual covariance is updated.
Expand Down Expand Up @@ -180,8 +186,8 @@
#'
mr.mash.rss <- function(Bhat, Shat, Z, R, covY, n, S0, w0=rep(1/(length(S0)), length(S0)), V,
mu1_init=NULL, tol=1e-4, convergence_criterion=c("mu1", "ELBO"),
max_iter=5000, update_w0=TRUE, update_w0_method="EM",
w0_threshold=0, compute_ELBO=TRUE, standardize=FALSE, verbose=TRUE,
max_iter=5000, update_w0=TRUE, update_w0_method="EM", w0_penalty=rep(1, length(S0)),
update_w0_max_iter=Inf, w0_threshold=0, compute_ELBO=TRUE, standardize=FALSE, verbose=TRUE,
update_V=FALSE, update_V_method=c("full", "diagonal"), version=c("Rcpp", "R"), e=1e-8,
ca_update_order=c("consecutive", "decreasing_logBF", "increasing_logBF", "random"),
X_colmeans=NULL, Y_colmeans=NULL, check_R=TRUE, R_tol=1e-08,
Expand Down Expand Up @@ -463,14 +469,15 @@ mr.mash.rss <- function(Bhat, Shat, Z, R, covY, n, S0, w0=rep(1/(length(S0)), le
}

##Update w0 if requested
if(update_w0){
w0 <- update_weights_em(w1_t)
if(update_w0 && t <= update_w0_max_iter){
w0 <- update_weights_em(w1_t, w0_penalty)

#Drop components with mixture weight <= w0_threshold
if(t>15 && any(w0 < w0_threshold)){
to_keep <- which(w0 >= w0_threshold)
w0 <- w0[to_keep]
w0 <- w0/sum(w0)
w0_penalty <- w0_penalty[to_keep]
S0 <- S0[to_keep]
if(length(to_keep) > 1){
comps <- filter_precomputed_quants(comps, to_keep, standardize, version)
Expand All @@ -483,7 +490,7 @@ mr.mash.rss <- function(Bhat, Shat, Z, R, covY, n, S0, w0=rep(1/(length(S0)), le
} else { #some other component is the only one left
stop("Only one component (different from the null) left. Consider lowering w0_threshold.")
}
}
}
}
}

Expand Down
4 changes: 2 additions & 2 deletions R/mr_mash_updates.R
Original file line number Diff line number Diff line change
Expand Up @@ -202,13 +202,13 @@ update_V_fun <- function(Y, mu, var_part_ERSS, Y_cov){


###Update mixture weights
update_weights_em <- function(x){
update_weights_em <- function(x, lambda){
w <- colSums(x)
w <- w + lambda - 1
w <- w/sum(w)
return(w)
}


###Impute/update missing Y
impute_missing_Y_R <- function(Y, mu, Vinv, miss, non_miss){
n <- nrow(Y)
Expand Down
10 changes: 9 additions & 1 deletion man/mr.mash.Rd

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

12 changes: 10 additions & 2 deletions man/mr.mash.rss.Rd

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

14 changes: 7 additions & 7 deletions src/mr_mash_updates.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -238,14 +238,14 @@ void impute_missing_Y (mat& Y, const mat& mu, const mat& Vinv,
Y_cov += Y_cov_i;

// Compute mean
vec Y_i = Y.row(i);
vec mu_i_miss = mu.row(i);
rowvec Y_i = Y.row(i);
rowvec mu_i_miss = mu.row(i);
mu_i_miss = mu_i_miss.elem(miss_i_idx);
vec mu_i_non_miss = mu.row(i);
mu_i_non_miss = mu_i_non_miss.elem(non_miss_i_idx);
vec Y_i_non_miss = Y.row(i);
Y_i_non_miss = Y_i_non_miss.elem(non_miss_i_idx);
Y_i.elem(miss_i_idx) = mu_i_miss - Y_cov_mm * Vinv_mo * (Y_i_non_miss - mu_i_non_miss);
rowvec mu_i_non_miss = mu.row(i);
mu_i_non_miss = trans(mu_i_non_miss.elem(non_miss_i_idx));
rowvec Y_i_non_miss = Y.row(i);
Y_i_non_miss = trans(Y_i_non_miss.elem(non_miss_i_idx));
Y_i.elem(miss_i_idx) = mu_i_miss - Y_cov_mm * Vinv_mo * trans(Y_i_non_miss - mu_i_non_miss);

Y.row(i) = Y_i;

Expand Down
Loading