Skip to content

Commit

Permalink
Fixing layout, grammer and spelling errors. Adding error messages and…
Browse files Browse the repository at this point in the history
… test.
  • Loading branch information
katrinekirkeby committed May 15, 2019
1 parent 9713ed2 commit 17d0ccf
Show file tree
Hide file tree
Showing 65 changed files with 794 additions and 453 deletions.
6 changes: 2 additions & 4 deletions R/BIC_junction_tree.R
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
#' Calculates the BIC value
#'
#' @description Calculates the BIC value for a graphical model where the
#' graph has the given junction tree.
#' @description Calculates the BIC value for a graphical model from a junction tree
#' for the graph.
#'
#' @param cliques A list containing the cliques of the junction tree.
#' @param separators A list containing the separators of the junction
Expand Down Expand Up @@ -30,7 +30,6 @@
#' determining the number of free parameters in the model.
#'
#' @examples
#'
#' set.seed(43)
#' var1 <- c(sample(c(1, 2), 100, replace = TRUE))
#' var2 <- var1 + c(sample(c(1, 2), 100, replace = TRUE))
Expand Down Expand Up @@ -63,7 +62,6 @@
#' # smooth is used to deal with zero probabilities.
#' BIC_junction_tree(cliques, separators, data, smooth = 0.1)
#' @export
#'

BIC_junction_tree <- function(cliques, separators, data,
base_log = 2, ...){
Expand Down
48 changes: 41 additions & 7 deletions R/CL.R
Original file line number Diff line number Diff line change
Expand Up @@ -4,15 +4,19 @@
#' an adjacency matrix, is acyclic.
#'
#' @param adj_matrix The adjacency matrix representing the graph.
#'
#' @details Notice that the function cannot cope with loops.
#' If the graph has loops, an error is returned.
#'
#' @return A logical value indicating whether the graph is acyclic.
#'
#' @author
#' Katrine Kirkeby, \email{enir_tak@@hotmail.com}
#'
#' Maria Knudsen, \email{mariaknudsen@@hotmail.dk}
#'
#' Ninna Vihrs, \email{ninnavihrs@@hotmail.dk}
#'
#' @examples
#' adj_matrix_cyclic <- matrix(c(0, 1, 1, 1,
#' 1, 0, 0, 1,
Expand All @@ -32,9 +36,29 @@
#' @export

is_acyclic <- function(adj_matrix){

if (! is.matrix(adj_matrix)){
stop("Argument must be a matrix.")
}

if (any(diag(adj_matrix) == 1)){
stop("The graph represented by the matrix contains loops.")
}

if (! is.numeric(adj_matrix)){
stop("Argument must be numeric.")
}

if (any(! c(adj_matrix) %in% 0:1)){
stop(paste("Argument must be an adjacency matrix for an unweighted graph.",
"Therefore all entries must be 0 or 1.", sep = " "))
}

if (! isSymmetric(adj_matrix)){
stop(paste("Only undirected graphs are supported so argument must be",
"symmetric. This includes that rownames must equal colnames.", sep = " "))
}

while (any(rowSums(adj_matrix) == 1 |
rowSums(adj_matrix) == 0)){
idx <- which(rowSums(adj_matrix) == 1 |
Expand Down Expand Up @@ -103,6 +127,7 @@ is_acyclic <- function(adj_matrix){

ChowLiu <- function(data, root = NULL, bayes_smooth = 0,
CPTs = TRUE, ...){

if (any(is.na(data))){
warning(paste("The data contains NA values.",
"Theese will be excluded from tables,",
Expand All @@ -115,13 +140,25 @@ ChowLiu <- function(data, root = NULL, bayes_smooth = 0,
if (! (is.data.frame(data) | is.matrix(data))) {
stop("data must be a data frame or a matrix.")
}

data <- as.data.frame(data)

if (! all(sapply(data, function(x){
is.character(x) | is.factor(x)
}
))){
stop("Some columns are not characters or factors.")
}

if (length(bayes_smooth) > 1){
stop("bayes_smooth must be a single non-negative value.")
}
else if (!is.numeric(bayes_smooth)) {
stop("bayes_smooth must be numeric.")
}
else if (bayes_smooth < 0){
stop("bayes_smooth must be a non-negative numeric value.")
}

# Calculating mutual information
nodes <- colnames(data)
Expand All @@ -146,7 +183,7 @@ ChowLiu <- function(data, root = NULL, bayes_smooth = 0,
MI_tab <- MI_tab[ord_idx, ]
rownames(MI_tab) <- NULL

# Construct skeleton for Chow-Liu tree
# Construct skeleton for Chow-Liu tree.
adj_matrix <- matrix(0, nrow = n_var, ncol = n_var)
rownames(adj_matrix) <- colnames(adj_matrix) <- nodes
i <- 1
Expand All @@ -167,7 +204,7 @@ ChowLiu <- function(data, root = NULL, bayes_smooth = 0,

skeleton_adj <- adj_matrix

# Determine DAG
# Determine DAG.
if (is.null(root)){
root <- sample(nodes, 1)
}
Expand All @@ -187,16 +224,13 @@ ChowLiu <- function(data, root = NULL, bayes_smooth = 0,
root <- nodes[kids_idx]
}

# Calculate conditional probability tables
# Calculate conditional probability tables.
if (CPTs){
CPTs <- CPT(adj_matrix_directed, data,
bayes_smooth = bayes_smooth)


}

return(list("skeleton_adj" = skeleton_adj,
"adj_DAG" = adj_matrix_directed,
"CPTs" = CPTs))

}
}
76 changes: 75 additions & 1 deletion R/CPT.R
Original file line number Diff line number Diff line change
@@ -1,3 +1,18 @@
# Internal helper: tests whether the directed graph described by a square
# 0/1 adjacency matrix is acyclic.
#
# The check peels the graph one node at a time: as long as some row sums to
# zero, the first such node is deleted (its row and its column). If every
# node can be deleted this way the matrix ends up empty and the graph is a
# DAG; if the peeling stops while nodes remain, the graph is not a DAG.
#
# adj_matrix: a numeric matrix accepted by rowSums(). The function does no
#             validation itself.
# Returns a single logical: TRUE if the matrix could be emptied, else FALSE.
isDAG <- function(adj_matrix){
  remaining <- adj_matrix
  row_totals <- rowSums(remaining)

  while (any(row_totals == 0)) {
    first_zero <- which(row_totals == 0)[1]
    # Indexing drops to a plain vector once a single row/column is left,
    # so the result is wrapped in as.matrix() to keep rowSums()/nrow() valid.
    remaining <- as.matrix(remaining[- first_zero, - first_zero])
    row_totals <- rowSums(remaining)
  }

  # An empty matrix means every node was peeled off.
  nrow(remaining) == 0
}

#' Estimate conditional probability tables
#'
#' @description Estimates the conditional probability tables for
Expand Down Expand Up @@ -28,6 +43,7 @@
#' var3 <- var1 + c(sample(c(0, 1), 50, replace = TRUE,
#' prob = c(0.9, 0.1)))
#' var4 <- c(sample(c(1, 2), 50, replace = TRUE))
#'
#' data <- data.frame("var1" = as.character(var1),
#' "var2" = as.character(var2),
#' "var3" = as.character(var3),
Expand All @@ -38,6 +54,9 @@
#' 1, 0, 0, 0,
#' 0, 1, 0, 0),
#' nrow = 4)
#'
#' rownames(adj_matrix_DAG) <- colnames(adj_matrix_DAG) <- names(data)
#'
#' CPT(adj_matrix_DAG, data)
#' CPT(adj_matrix_DAG, data, bayes_smooth = 1)
#' @export
Expand All @@ -53,13 +72,68 @@ CPT <- function(adj_matrix, data, bayes_smooth = 0){
sep = " "))
}

if (! (is.data.frame(data) | is.matrix(data))) {
stop("data must be a data frame or a matrix.")
}

data <- data.frame(data, stringsAsFactors = FALSE)

if (! all(sapply(data, function(x){
is.character(x) | is.factor(x)
}
))){
stop("Some columns are not characters or factors.")
}

if (! is.matrix(adj_matrix)){
stop("adj_matrix must be a matrix.")
}

if (any(diag(adj_matrix) == 1)){
stop("The graph represented by adj_matrix contains loops.")
}

if (! is.numeric(adj_matrix)){
stop("adj_matrix must be numeric.")
}

if (any(! c(adj_matrix) %in% 0:1)){
stop(paste("adj_matrix must be an adjacency matrix for an unweighted graph.",
"Therefore all entries must be 0 or 1.", sep = " "))
}

if (is.null(colnames(adj_matrix)) | is.null(rownames(adj_matrix))){
stop("adj_matrix must be named.")
}

if (any(colnames(adj_matrix) != rownames(adj_matrix))){
stop("Names of columns and rows in adj_matrix must be the same.")
}

if (length(setdiff(colnames(adj_matrix), names(data))) != 0){
stop("The names of adj_matrix must be variable names in data.")
}

if (! isDAG(adj_matrix)){
stop("adj_matrix is not a DAG.")
}

if (length(bayes_smooth) > 1){
stop("bayes_smooth must be a single non-negative value.")
}
else if (!is.numeric(bayes_smooth)) {
stop("bayes_smooth must be numeric.")
}
else if (bayes_smooth < 0){
stop("bayes_smooth must be a non-negative numeric value.")
}

nodes <- rownames(adj_matrix)
FUN <- function(node){
parents_idx <- which(adj_matrix[, node] == 1)
parents <- nodes[parents_idx]

tab <- table(data[, c(node, parents)]) + bayes_smooth

if (length(parents) == 0){
mar <- NULL
names(dimnames(tab)) <- node
Expand Down
23 changes: 19 additions & 4 deletions R/MI.R
Original file line number Diff line number Diff line change
@@ -1,13 +1,14 @@
#' Calculates mutual information
#'
#' @description Calculate mutual information for two or three
#' categorical variables.
#'
#' @param x,y,z Vectors of class character or factor.
#' @param smooth Additional cell counts for Bayesian estimation of
#' probabilities.
#' @param log_base The base of the logarithmic function to be used.
#'
#' @details
#'
#' The mutual information for two variables is calculated by the
#' formula \deqn{MI(x, y) = \sum P(x, y) log(P(x, y) / (P(x)P(y)))}
#' where the sum is over all possible values of x and y.
Expand All @@ -16,37 +17,49 @@
#' formula \deqn{MI(x, y, z) = \sum P(x, y, z) log(P(x, y, z) / (P
#' (x)P(y)P(z)))} where the sum is over all possible values of x, y and
#' z.
#'
#' @return The mutual information given by a single numeric value.
#'
#' @author
#' Katrine Kirkeby, \email{enir_tak@@hotmail.com}
#'
#' Maria Knudsen, \email{mariaknudsen@@hotmail.dk}
#'
#' Ninna Vihrs, \email{ninnavihrs@@hotmail.dk}
#'
#' @importFrom Rdpack reprompt
#'
#' @references
#' \insertRef{TCJT}{tcherry}
#'
#' \insertRef{EKTS}{tcherry}
#'
#' @seealso
#' \code{\link{MIk}} for mutual information for k variables.
#'
#' @examples
#'
#' var1 <- c(sample(c(1, 2), 100, replace = TRUE))
#' var2 <- var1 + c(sample(c(1, 2), 100, replace = TRUE))
#' var3 <- c(sample(c(1, 2), 100, replace = TRUE))
#' var1 <- as.character(var1)
#' var2 <- as.character(var2)
#' var3 <- as.character(var3)
#'
#' MI2(var1, var2, smooth = 1)
#' MI2(var1, var2, smooth = 0.1, log_base = exp(1))
#'
#' MI3(var1, var2, var3, smooth = 1)
#' MI3(var1, var2, var3, smooth = 0.1, log_base = exp(1))
#' @export

MI2 <- function(x, y, smooth = 0, log_base = 2){

if (! all(sapply(list(x, y), function(x){
is.character(x) | is.factor(x)
}
))){
stop("x and y must be either characters or factors.")
}

if (length(smooth) > 1){
stop("smooth must be a single non-negative value.")
}
Expand Down Expand Up @@ -82,12 +95,14 @@ MI2 <- function(x, y, smooth = 0, log_base = 2){
#' @export

MI3 <- function(x, y, z, smooth = 0, log_base = 2){

if (! all(sapply(list(x, y, z), function(x){
is.character(x) | is.factor(x)
}
))){
stop("x, y and z must be either characters or factors.")
}

if (length(smooth) > 1){
stop("smooth must be a single non-negative value.")
}
Expand Down Expand Up @@ -120,4 +135,4 @@ MI3 <- function(x, y, z, smooth = 0, log_base = 2){
MI <- sum(prop_xyz * log(frac_prop_MI, base = log_base))

return(MI)
}
}
5 changes: 3 additions & 2 deletions R/MI_k.R
Original file line number Diff line number Diff line change
Expand Up @@ -53,6 +53,7 @@
#' "var5" = as.character(var5),
#' "var6" = as.character(var6),
#' "var7" = as.character(var7))
#'
#' MIk(c("var1", "var2", "var7"), data, smooth = 0.001)
#' @export

Expand All @@ -73,7 +74,7 @@ MIk <- function(variables, data, smooth = 0, log_base = 2){

data <- as.data.frame(data)

if (!all(variables %in% colnames(data))){
if (! all(variables %in% colnames(data))){
stop("All names in variables must be column names of data.")
}

Expand Down Expand Up @@ -124,4 +125,4 @@ MIk <- function(variables, data, smooth = 0, log_base = 2){
log(frac_prop_MI, base = log_base)))

return(MI)
}
}
Loading

0 comments on commit 17d0ccf

Please sign in to comment.