Skip to content

Commit

Permalink
Added cervical cancer example and improved minor things
Browse files Browse the repository at this point in the history
  • Loading branch information
Goerke authored and Goerke committed Aug 2, 2019
1 parent 44d99a6 commit 75d355b
Show file tree
Hide file tree
Showing 7 changed files with 132 additions and 81 deletions.
2 changes: 1 addition & 1 deletion R/RPerturbFun_tabular_featureless_disc.R
Original file line number Diff line number Diff line change
Expand Up @@ -31,7 +31,7 @@ perturbate.tabular.featurelessDisc <-
next
}
# Check bin of instance
providedBin = provideBin.numeric(instance[i], bin)
providedBin = provideBin(instance[i], bin)
binsNo = 1:(length(bins[[i]]$cuts) + 1)

instance[, i] = sample(c(providedBin, binsNo[-providedBin]), 1,
Expand Down
3 changes: 1 addition & 2 deletions R/anchors_dataframe.R
Original file line number Diff line number Diff line change
Expand Up @@ -37,8 +37,7 @@ anchors.data.frame <-
if (is.null(bins)) {
bins <- create.empty.discretization(predictorCount)
}
validate.bins(bins, predictorCount)
explainer$bins <- bins
explainer$bins <- validate.bins(bins, predictorCount)

if (is.null(perturbator))
perturbator <- makePerturbFun("tabular.featureless")
Expand Down
2 changes: 1 addition & 1 deletion R/coverageFun.R
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,7 @@ calculate.coverage <-
colnames(reducedPerturbations) = features

for (i in 1:ncol(reducedPerturbations)) {
featureVec[i] = provideBin.numeric(featureVec[i], bins[[features[i]]])
featureVec[i] = provideBin(featureVec[i], bins[[features[i]]])
}

matchingRows = nrow(suppressMessages(plyr::match_df(reducedPerturbations, featureVec)))
Expand Down
104 changes: 73 additions & 31 deletions R/helpers.R
Original file line number Diff line number Diff line change
Expand Up @@ -21,27 +21,46 @@ validate.bins <- function(bins, length) {
checkmate::expect_list(bins)
if (length(bins) != length)
stop("There needs to be one bin defined for each element")
for (bin in bins) {
for (i in 1:length(bins)) {
bin <- bins[[i]]
if (length(which(!(
names(bin) %in% c("doDiscretize", "numeric", "classes", "cuts", "right")
))) > 0)
stop("Invalid bin arguments")
if (!is.null(bin$doDiscretize) && bin$doDiscretize == F)
next

if (!is.null(bin$numeric) && bin$numeric == T) {
if (is.null(bin$numeric)) {
if ((is.null(bin$cuts) && is.null(bin$classes)) ||
(!is.null(bin$cuts) && !is.null(bin$classes)))
stop("Either classes or cuts have to be provided")

if (!is.null(bin$cuts))
bin$numeric <- T
if (!is.null(bin$classes))
bin$numeric <- F
}

if (is.null(bin$right))
bin$right <- F

if (bin$numeric == T) {
checkmate::assert_vector(bin$cuts)
checkmate::assert_null(bin$classes)
} else {
checkmate::assert_list(bin$classes)
checkmate::assert_null(bin$cuts)
checkmate::assert_null(bin$right)
}

bins[[i]] <- bin
}

return(bins)
}


provideBin.numeric <- function(value, bin) {
provideBin <- function(value, bin) {
# If discretization is disabled return value
if (!is.null(bin$doDiscretize) && !bin$doDiscretize) {
return(value)
Expand Down Expand Up @@ -69,45 +88,68 @@ provideBin.numeric <- function(value, bin) {
}


plotExplanations <- function(explanations, featureNames){


d = matrix(rep(0,length(featureNames)*length(unique(explanations[,"case"]))), ncol=length(featureNames))
plotExplanations <- function(explanations, featureNames) {
d = matrix(rep(0, length(featureNames) * length(unique(explanations[, "case"]))), ncol =
length(featureNames))

colnames(d) = featureNames

rownames(d) = unique(explanations[,"case"])

bins=as.data.frame(d)
sapply(colnames(d), function(featureName){
cases = unique(explanations[,"case"])
sapply(cases, function(case, featureName){
if(featureName %in% explanations[explanations[,"case"]==case, "feature"]){
d[case, featureName]<<- explanations[explanations[,"case"]==case & explanations[,"feature"]==featureName, "feature_weight"]
bins[case, featureName]<<- explanations[explanations[,"case"]==case & explanations[,"feature"]==featureName, "feature_desc"]
rownames(d) = unique(explanations[, "case"])

bins = as.data.frame(d)
sapply(colnames(d), function(featureName) {
cases = unique(explanations[, "case"])
sapply(cases, function(case, featureName) {
if (featureName %in% explanations[explanations[, "case"] == case, "feature"]) {
d[case, featureName] <<-
explanations[explanations[, "case"] == case &
explanations[, "feature"] == featureName, "feature_weight"]
bins[case, featureName] <<-
explanations[explanations[, "case"] == case &
explanations[, "feature"] == featureName, "feature_desc"]
}
}, featureName)
})

par(mfrow=c(nrow(d), 1), mar=c(5, 4, 4, 7) + 0.1)
colors=brewer.pal(n = 5, name = 'Blues')
cuts=seq(0.2,1,0.2)
r=sapply(1:nrow(d), function(i){
xlab=""
if(i==nrow(d)){
xlab="Features"
}
colorBorders=sapply(1:length(d[i,]), function(x) {
return(min(which(cuts>=d[i,x])))
par(mfrow = c(nrow(d), 1), mar = c(5, 4, 4, 7) + 0.1)
colors = brewer.pal(n = 5, name = 'Blues')
cuts = seq(0.2, 1, 0.2)
r = sapply(1:nrow(d), function(i) {
xlab = ""
if (i == nrow(d)) {
xlab = "Features"
}
)
p<-barplot(ifelse(d[i,]==0, NA, d[i,]), axes=F, ylab=paste("Instance",i), xlab=xlab, names.arg=colnames(d), ylim=c(0,1), col=colors[colorBorders])
text(p, 0, ifelse(bins[i,]==0, "", substr(bins[i,], start=nchar(colnames(d))+1, stop=100000)), cex=0.7, pos=3)
colorBorders = sapply(1:length(d[i, ]), function(x) {
return(min(which(cuts >= d[i, x])))
})
p <-
barplot(
ifelse(d[i, ] == 0, NA, d[i, ]),
axes = F,
ylab = paste("Instance", i),
xlab = xlab,
names.arg = colnames(d),
ylim = c(0, 1),
col = colors[colorBorders]
)
text(p,
0,
ifelse(bins[i, ] == 0, "", substr(
bins[i, ], start = nchar(colnames(d)) + 1, stop = 100000
)),
cex = 0.7,
pos = 3)
})
legend("bottomright",legend=seq(0.2,1,0.2), fill=brewer.pal(n = 5, name = 'Blues'), xpd=TRUE, inset=c(-0.1,0))
legend(
"bottomright",
legend = seq(0.2, 1, 0.2),
fill = brewer.pal(n = 5, name = 'Blues'),
xpd = TRUE,
inset = c(-0.1, 0)
)
}

buildDescription.numeric <- function(bin, cuts, right) {
buildDescription <- function(bin, cuts, right) {
desc = ""
if (bin == 1) {
if (right) {
Expand Down
4 changes: 2 additions & 2 deletions R/processRules.R
Original file line number Diff line number Diff line change
Expand Up @@ -22,12 +22,12 @@ getAddedCoverage <- function(candidates, feature, instance, dataset){
getFeatureText <- function(candidates, feature, instance, dataset, bins){
bin <- bins[[feature+1]]
if (candidates$addedFeature == feature){
providedBin = provideBin.numeric(instance[feature+1], bin)
providedBin = provideBin(instance[feature+1], bin)
if (!is.null(bin$doDiscretize) && !bin$doDiscretize) {
featureDesc = paste(colnames(dataset)[feature+1], "=", providedBin)
}
else {
featureDesc = paste(colnames(dataset)[feature+1], "IN", buildDescription.numeric(providedBin, bin$cuts, bin$right))
featureDesc = paste(colnames(dataset)[feature+1], "IN", buildDescription(providedBin, bin$cuts, bin$right))
}
names(featureDesc) = feature+1
return(featureDesc)
Expand Down
97 changes: 54 additions & 43 deletions inst/examples/mlr_cervical.R
Original file line number Diff line number Diff line change
@@ -1,51 +1,62 @@
library(mlr)

load(paste0(getwd(), "/inst/examples/ExampleCancer/cervical.RData"))

# Our goal is to predict whether an individual has cancer
task = makeClassifTask(data = cervical, target = "Biopsy", id = "Biopsy")

# setting up a learner
lrn.rpart = makeLearner("classif.rpart")

# train the learner on the training set
model = mlr::train(learner = lrn.rpart, task = task)


# Set up a perturbation function. Since we want to explain a tabular instance (an observation in our cervical dataset), we stick to a featureless tabular perturbation function
perturbator = makePerturbFun("tabular.featureless")

# discretizing the dataset
## TODO: add a better discretization
discCervical = cervical
discCervical[,"Smokes"] = NA
discCervical[,"Smokes..years."] = NA
discCervical[,"Hormonal.Contraceptives"] = NA
discCervical[,"Hormonal.Contraceptives..years."] = NA
discCervical[,"IUD"] = NA
discCervical[,"IUD..years."] = NA
discCervical[,"STDs"] = NA
discCervical[,"STDs..number."] = NA
discCervical[,"STDs..Number.of.diagnosis"] = NA
discCervical[,"STDs..Time.since.first.diagnosis"] = NA
discCervical[,"STDs..Time.since.last.diagnosis"] = NA
discCervical = arules::discretizeDF(discCervical)
discCervical[,"Smokes"] = arules::discretize(cervical[,"Smokes"], breaks =1 )
discCervical[,"Smokes..years."] = arules::discretize(cervical[,"Smokes..years."], breaks =1)
discCervical[,"Hormonal.Contraceptives"] = arules::discretize(cervical[,"Hormonal.Contraceptives"], breaks =1)
discCervical[,"Hormonal.Contraceptives..years."] = arules::discretize(cervical[,"Hormonal.Contraceptives..years."], breaks =2)
discCervical[,"IUD"]= arules::discretize(cervical[,"IUD"], breaks =1 )
discCervical[,"IUD..years."]= arules::discretize(cervical[,"IUD..years."], breaks =1)
discCervical[,"STDs"]= arules::discretize(cervical[,"STDs"], breaks =1)
discCervical[,"STDs..number."]= arules::discretize(cervical[,"STDs..number."], breaks =1)
discCervical[,"STDs..Number.of.diagnosis"]= arules::discretize(cervical[,"STDs..Number.of.diagnosis"], breaks =1)
discCervical[,"STDs..Time.since.first.diagnosis"]= arules::discretize(cervical[,"STDs..Time.since.first.diagnosis"], breaks =1)
discCervical[,"STDs..Time.since.last.diagnosis"]= arules::discretize(cervical[,"STDs..Time.since.last.diagnosis"], breaks =1)
load("inst/examples/ExampleCancer/cervical.RData")

cervical_label_cancer = cervical[cervical$Biopsy == "Cancer",]
cervical_label_healthy = cervical[cervical $Biopsy == "Healthy",]
cervical_label_healthy = cervical_label_healthy[sample(1:nrow(cervical_label_healthy), nrow(cervical_label_cancer)), ]
cervical = rbind(cervical_label_cancer, cervical_label_healthy)

cervical.task = makeClassifTask(data = cervical, target = "Biopsy")
model = mlr::train(mlr::makeLearner(cl = 'classif.rpart', id = 'cervical-rf', predict.type = 'prob'), cervical.task)

# Visualize
rpart.plot::rpart.plot(getLearnerModel(model))

bins <- list()
for (i in 1:(ncol(cervical)-1)) {
bins[[i]] <- list()
bins[[i]]$doDiscretize <- T
#bins[[i]]$numeric <- T
#bins[[i]]$right =
}

# Age
bins[[1]]$cuts <- c(15, 25, 35, 50, 60)
# Number.of.sexual.partners
bins[[2]]$cuts <- arules::discretize(cervical[, 2], breaks = 2, onlycuts = T)
# First.sexual.intercourse
bins[[3]]$cuts <- arules::discretize(cervical[, 3], breaks = 4, onlycuts = T)
# Num.of.pregnancies
bins[[4]]$cuts <- c(0, 1, 2, 4)
# Smokes
bins[[5]]$doDiscretize <- F
# Smokes..years.
bins[[6]]$cuts <- c(0, 2, 5, 10)
# Hormonal.Contraceptives
bins[[7]]$doDiscretize <- F
# Hormonal.Contraceptives..years.
bins[[8]]$cuts <- arules::discretize(cervical[, 8], method = "cluster", breaks = 4, onlycuts = T)
# IUD
bins[[9]]$doDiscretize <- F
# IUD..years.
bins[[10]]$cuts <- arules::discretize(cervical[, 10], method = "cluster", breaks = 4, onlycuts = T)
# STDs
bins[[11]]$doDiscretize <- F
# STDs..number.
bins[[12]]$cuts <- c(0, 1)
# STDs..Number.of.diagnosis
bins[[13]]$cuts <- c(0, 1)
# STDs..Time.since.first.diagnosis
bins[[14]]$cuts <- c(0, 3)
# STDs..Time.since.last.diagnosis
bins[[15]]$cuts <- c(0, 3)


# Explain model with anchors
explainer = anchors(cervical, model, perturbator, discX = discCervical)
explainer = anchors(cervical, model, bins = bins)

explanations = explain(cervical[1:2,], explainer)
explanations = explain(cervical[3,], explainer)

printExplanations(explainer, explanations)

1 change: 0 additions & 1 deletion inst/examples/mlr_iris.R
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,6 @@ rpart.plot::rpart.plot(getLearnerModel(model))
bins = list()
r = sapply(1:(ncol(iris) - 1), function(x) {
bin <<- list()
bin$numeric <<- T
cuts = arules::discretize(iris[, x], onlycuts = T)
bin$cuts <<- cuts[2:(length(cuts) - 1)]
bin$right <<- F
Expand Down

0 comments on commit 75d355b

Please sign in to comment.