-
Notifications
You must be signed in to change notification settings - Fork 15
Commit
- Loading branch information
There are no files selected for viewing
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -26,6 +26,7 @@ Imports: | |
laeken, | ||
ranger, | ||
MASS, | ||
xgboost, | ||
data.table(>= 1.9.4) | ||
Suggests: | ||
dplyr, | ||
|
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,132 @@ | ||
#' Xgboost Imputation
#'
#' Impute missing values based on a gradient boosting model using
#' [xgboost::xgboost()]. One model is fitted per variable on the left-hand
#' side of `formula`, using the rows where both that variable and all
#' predictors are observed; rows with a missing predictor are left untouched.
#' @param formula model formula for the imputation. Variables to impute are
#' listed on the left-hand side and predictors on the right-hand side, each
#' separated by `+`.
#' @param data A `data.frame` containing the data
#' @param imp_var `TRUE`/`FALSE` if a `TRUE`/`FALSE` variables for each imputed
#' variable should be created show the imputation status
#' @param imp_suffix suffix used for TF imputation variables
#' @param verbose Show the number of observations used for training
#' and evaluating the model. This parameter is also passed down to
#' [xgboost::xgboost()] to show computation status.
#' @param ... Arguments passed to [xgboost::xgboost()]
#' @param nrounds max number of boosting iterations,
#' argument passed to [xgboost::xgboost()]
#' @param objective objective for xgboost. If supplied, it must contain one
#' entry per variable on the left-hand side of `formula`. Note that the
#' objective is currently derived from the type of each response
#' (factor, integer, numeric) and that automatic choice takes precedence
#' over a supplied value.
#' @return the imputed data set.
#' @family imputation methods
#' @examples
#' data(sleep)
#' xgboostImpute(Dream~BodyWgt+BrainWgt,data=sleep)
#' xgboostImpute(Dream+NonD~BodyWgt+BrainWgt,data=sleep)
#' xgboostImpute(Dream+NonD+Gest~BodyWgt+BrainWgt,data=sleep)
#'
#' sleepx <- sleep
#' sleepx$Pred <- as.factor(LETTERS[sleepx$Pred])
#' sleepx$Pred[1] <- NA
#' xgboostImpute(Pred~BodyWgt+BrainWgt,data=sleepx)
#' @export
xgboostImpute <- function(formula, data, imp_var = TRUE,
                          imp_suffix = "imp", verbose = FALSE,
                          nrounds=100, objective=NULL,
                          ...){
  check_data(data)
  formchar <- as.character(formula)
  lhs <- gsub(" ", "", strsplit(formchar[2], "\\+")[[1]])
  rhs <- formchar[3]
  rhs2 <- gsub(" ", "", strsplit(rhs, "\\+")[[1]])
  # Rows with a missing predictor can neither be used for training nor imputed
  rhs_na <- rowSums(is.na(subset(data, select = rhs2))) > 0
  # objective should be a vector of length equal to the lhs variables
  if (!is.null(objective)) {
    stopifnot(length(objective) == length(lhs))
  }
  for (i in seq_along(lhs)) {
    lhsV <- lhs[i]
    form <- as.formula(paste(lhsV, "~", rhs, "-1"))
    # formula without left side for prediction
    formPred <- as.formula(paste("~", rhs, "-1"))
    lhs_vector <- data[[lhsV]]
    lhs_na <- is.na(lhs_vector)
    if (!any(lhs_na)) {
      cat(paste0("No missings in ", lhsV, ".\n"))
    } else {
      train_rows <- !rhs_na & !lhs_na
      impute_rows <- !rhs_na & lhs_na
      if (verbose)
        message("Training model for ", lhsV, " on ", sum(train_rows), " observations")
      dattmp <- subset(data, train_rows)
      labtmp <- dattmp[[lhsV]]

      # Per-variable state: kept local to this iteration so that the choice
      # made for one response never leaks into the next one.
      obj <- if (is.null(objective)) NULL else objective[[i]]
      num_class <- NULL
      currentClass <- NULL
      lvlsFac <- NULL
      lvlsInt <- NULL

      if (inherits(labtmp, "factor")) {
        currentClass <- "factor"
        # Drop levels that do not occur in the training rows so that the
        # integer labels are contiguous (0, 1, ..., k - 1) as xgboost requires.
        labtmp <- droplevels(labtmp)
        lvlsFac <- levels(labtmp)
        labtmp <- as.integer(labtmp) - 1
        if (length(lvlsFac) == 2) {
          obj <- "binary:logistic"
        } else if (length(lvlsFac) > 2) {
          obj <- "multi:softmax"
          num_class <- length(lvlsFac)
        }
      } else if (inherits(labtmp, "integer") || inherits(labtmp, "numeric")) {
        currentClass <- if (inherits(labtmp, "integer")) "integer" else "numeric"
        if (length(unique(labtmp)) == 2) {
          # Two distinct values: treat as a binary outcome and remember the
          # original values to map the predictions back afterwards.
          lvlsInt <- unique(labtmp)
          labtmp <- match(labtmp, lvlsInt) - 1
          warning("binary factor detected but not properly stored as factor.")
          obj <- "binary:logistic"
        } else if (currentClass == "integer") {
          obj <- "count:poisson"## Todo: this might not be wise as default
        } else {
          obj <- "reg:squarederror"
        }
      } else {
        stop("Variable '", lhsV, "' has unsupported class '",
             paste(class(labtmp), collapse = "/"),
             "'; only factor, integer and numeric variables can be imputed.")
      }

      mm <- model.matrix(form, dattmp)
      if (!is.null(num_class)) {
        mod <- xgboost::xgboost(data = mm, label = labtmp,
                                nrounds = nrounds, objective = obj,
                                num_class = num_class, verbose = verbose, ...)
      } else {
        mod <- xgboost::xgboost(data = mm, label = labtmp,
                                nrounds = nrounds, objective = obj,
                                verbose = verbose, ...)
      }

      if (verbose)
        message("Evaluating model for ", lhsV, " on ", sum(impute_rows), " observations")
      predictions <-
        predict(mod, model.matrix(formPred, subset(data, impute_rows)))

      if (currentClass == "factor") {
        if (is.null(num_class)) {
          # binary:logistic returns probabilities of the second level
          data[impute_rows, lhsV] <- lvlsFac[as.numeric(predictions > .5) + 1]
        } else {
          # multi:softmax returns the zero-based class index
          data[impute_rows, lhsV] <- lvlsFac[predictions + 1]
        }
      } else if (identical(obj, "binary:logistic")) {
        data[impute_rows, lhsV] <- lvlsInt[as.numeric(predictions > .5) + 1]
      } else {
        data[impute_rows, lhsV] <- predictions
      }
    }

    if (imp_var) {
      imp_col <- paste0(lhsV, "_", imp_suffix)
      if (imp_col %in% colnames(data)) {
        # Re-use the existing indicator and flag the values that were
        # missing in this call on top of what it already records.
        imp_flag <- as.logical(data[[imp_col]])
        imp_flag[lhs_na] <- TRUE
        data[[imp_col]] <- imp_flag
        warning(paste("The following TRUE/FALSE imputation status variables will be updated:",
                      imp_col))
      } else {
        data[[imp_col]] <- lhs_na
      }
    }
  }
  data
}
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,26 @@ | ||
library(VIM) | ||
message("matchImpute general") | ||
# Return a copy of `d` in which the cells at rows `i` of column(s) `col`
# (second column by default) are replaced by NA.
setna <- function(d, i, col = 2) {
  blanked <- d
  blanked[i, col] <- NA
  blanked
}
# Fixture: six distinct rows stacked twice, so every row has an exact twin
# that can act as donor.
d <- data.frame(
  x = LETTERS[1:6],
  y = as.double(1:6),
  z = as.double(1:6),
  w = ordered(LETTERS[1:6]),
  stringsAsFactors = FALSE
)
dorig <- rbind(d, d)
# Same data with `x` removed in the second copy (rows 7 to 12)
x_missing_half <- setna(dorig, 7:12, 1)

# One matching variable: the imputed half must reproduce the observed half
d1 <- matchImpute(x_missing_half[, 1:2], match_var = "y", variable = "x")
expect_identical(d1$x[d1$x_imp], d1$x[!d1$x_imp])

# Same call without indicator column; d1's indicator is reused to locate
# the imputed rows because d1b does not carry one
d1b <- matchImpute(x_missing_half[, 1:2], match_var = "y", variable = "x",
                   imp_var = FALSE)
expect_identical(d1b$x[d1$x_imp], d1b$x[!d1$x_imp])
expect_false("x_imp" %in% colnames(d1b))
expect_true("x_imp" %in% colnames(d1))

# No observed value of x at all -> error
expect_error(
  matchImpute(setna(dorig, 1:12, 1)[, 1:2], match_var = "y", variable = "x")
)

# Two matching variables
d1 <- matchImpute(x_missing_half[, 1:3], match_var = c("y", "z"),
                  variable = "x")
expect_identical(d1$x[d1$x_imp], d1$x[!d1$x_imp])
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,49 @@ | ||
library(VIM) | ||
# Reproducible fixture: y is x plus a little noise, fac is the sign of x.
set.seed(104)
x <- rnorm(100)
y_noise <- rnorm(100, sd = .01)
df <- data.frame(y = x + y_noise, x = x, fac = as.factor(x >= 0))
|
||
# Largest absolute elementwise difference between `x` and `y`.
max_dist <- function(x, y) {
  gaps <- abs(x - y)
  max(gaps)
}
|
||
# Remove some responses. The two binary helper columns are derived from
# `fac` afterwards: binNum is stored as double, binInt as integer.
df$y[1:3] <- NA
df$fac[3:5] <- NA
fac_code <- as.integer(df$fac)
df$binNum <- fac_code + 17
df$binInt <- fac_code + 17L

# Accuracy: imputed y values stay close to x
df.out <- xgboostImpute(y ~ x, df)
worst_gap <- max_dist(df.out$y, df$x)
expect_true(worst_gap < 0.06)

# A response without missing values is returned unchanged
df.out <- xgboostImpute(x ~ y, df)
expect_identical(df.out$x, df$x)

# Factor response is predicted accurately
df.out <- xgboostImpute(fac ~ x, df)
expect_identical(df.out$fac, as.factor(df$x >= 0))

# Integer binary response is predicted accurately (with a warning)
expect_warning(df.out <- xgboostImpute(binInt ~ x, df))
expect_identical(df.out$binInt == 19, df$x >= 0)

# Numeric binary response is predicted accurately (with a warning)
expect_warning(df.out <- xgboostImpute(binNum ~ x, df))
expect_identical(df.out$binNum == 19, df$x >= 0)

# A factor regressor is used reasonably
df2 <- df
df2$x[1:10] <- NA
df.out <- xgboostImpute(x ~ fac, df2)
expect_identical(as.factor(df.out$x >= 0), df$fac)
|
Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.
Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.
Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.
Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.
Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.
Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.