diff --git a/scripts/builtin/shapExplainer.dml b/scripts/builtin/shapExplainer.dml new file mode 100644 index 00000000000..626dc7da4c9 --- /dev/null +++ b/scripts/builtin/shapExplainer.dml @@ -0,0 +1,732 @@ +#------------------------------------------------------------- +# +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +# +#------------------------------------------------------------- + +# Computes shapley values for multiple instances in parallel using antithetic permutation sampling. +# The resulting matrix phis holds the shapley values for each feature in the column given by the index of the feature in the sample. +# +# This method first creates two large matrices for masks and masked background data for all permutations and +# then runs in paralell on all instances in x. +# While the prepared matrices can become very large (2 * #features * #permuations * #n_samples * #features), +# the preparation of a row for the model call breaks down to a single element-wise multiplication of this mask with the row and +# an addition to the masked background data, since masks can be reused for each instance. +# +# INPUT: +# --------------------------------------------------------------------------------------- +# model_function The function of the model to be evaluated as a String. This function has to take a matrix of samples +# and return a vector of predictions. +# It might be usefull to wrap the model into a function the takes and returns the desired shapes and +# use this wrapper here. +# model_args Arguments in order for the model, if desired. This will be prepended by the created instances-matrix. +# x_instances Multiple instances as rows for which to compute the shapley values. +# X_bg The background dataset from which to pull the random samples to perform Monte Carlo integration. +# n_permutations The number of permutaions. Defaults to 10. Theoretical 1 should already be enough for models with up +# to second order interaction effects. +# n_samples Number of samples from X_bg used for marginalization. +# remove_non_var EXPERIMENTAL: If set, for every instance the varaince of each feature is checked against this feature in the +# background data. If it does not change, we do not run any model cals for it. +# seed A seed, in case the sampling has to be deterministic. +# verbose A boolean to enable logging of each step of the function. +# --------------------------------------------------------------------------------------- +# +# OUTPUT: +# ----------------------------------------------------------------------------- +# S Matrix holding the shapley values along the cols, one row per instance. +# expected Double holding the average prediction of all instances. +# ----------------------------------------------------------------------------- +s_shapExplainer = function(String model_function, list[unknown] model_args, Matrix[Double] x_instances, + Matrix[Double] X_bg, Integer n_permutations = 10, Integer n_samples = 100, Integer remove_non_var=0, + Matrix[Double] partitions=as.matrix(-1), Integer seed = -1, Integer verbose = 0) + return (Matrix[Double] row_phis, Double expected) +{ + u_printShapMessage("Parallel Permutation Explainer for "+nrow(x_instances)+" rows.", verbose) + u_printShapMessage("Number of Features: "+ncol(x_instances), verbose ) + total_preds=ncol(x_instances)*2*n_permutations*n_samples*nrow(x_instances) + u_printShapMessage("Number of predictions: "+toString(total_preds)+" in "+nrow(x_instances)+ + " parallel cals.", verbose ) + + #start with all features + features=u_range(1, ncol(x_instances)) + + #handle partitions + if(sum(partitions) != -1){ + if(remove_non_var != 0){ + stop("shapley_permutations_by_row:ERROR: Can't use n_non_varying_inds and partitions at the same time.") + } + features=removePartitionsFromFeatures(features, partitions) + reduced_total_preds=ncol(features)*2*n_permutations*n_samples*nrow(x_instances) + u_printShapMessage("Using Partitions reduces number of features to "+ncol(features)+".", verbose ) + u_printShapMessage("Total number of predictions reduced by "+(total_preds-reduced_total_preds)/total_preds+" to "+reduced_total_preds+".", verbose ) + } + + #lengths and offsets + total_features = ncol(x_instances) + perm_length = ncol(features) + full_mask_offset = perm_length * 2 * n_samples + n_partition_features = total_features - perm_length + + #sample from X_bg + u_printShapMessage("Sampling from X_bg", verbose ) + # could use new samples for each permutation by sampling n_samples*n_permutations + X_bg_samples = u_sample_with_potential_replace(X_bg=X_bg, samples=n_samples, seed=seed ) + row_phis = matrix(0, rows=nrow(x_instances), cols=total_features) + expected_m = matrix(0, rows=nrow(x_instances), cols=1) + + #prepare masks for all permutations, since it stays the same for every row + u_printShapMessage("Preparing reusable intermediate masks.", verbose ) + permutations = matrix(0, rows=n_permutations, cols=perm_length) + masks_for_permutations = matrix(0, rows=perm_length*2*n_permutations*n_samples, cols=total_features) + + parfor (i in 1:n_permutations, check=0){ + #shuffle features to get permutation + permutations[i] = t(u_shuffle(t(features))) + perm_mask = prepare_mask_for_permutation(permutation=permutations[i], partitions=partitions) + + offset_masks = (i-1) * full_mask_offset + 1 + masks_for_permutations[offset_masks:offset_masks+full_mask_offset-1]=prepare_full_mask(perm_mask, n_samples) + } + + #replicate background and mask it, since it also can stay the same for every row + # could use new samples for each permutation by sampling n_samples*n_permutations and telling this function about it + masked_bg_for_permutations = prepare_masked_X_bg(masks_for_permutations, X_bg_samples, 0) + u_printShapMessage("Computing phis in parallel.", verbose ) + + #enable spark execution for parfor if desired + #TODO allow spark mode via parameter? + #parfor (i in 1:nrow(x_instances), opt=CONSTRAINED, mode=REMOTE_SPARK){ + + parfor (i in 1:nrow(x_instances)){ + if(remove_non_var == 1){ + # try to remove inds that do not vary from the background + non_var_inds = get_non_varying_inds(x_instances[i], X_bg_samples) + # only remove if more than 2 features remain, less then two breaks removal procedure + if (ncol(x_instances) > length(non_var_inds)+2){ + #remove samples and masks for non varying features + [i_masks_for_permutations, i_masked_bg_for_permutations] = remove_inds(masks_for_permutations, masked_bg_for_permutations, permutations, non_var_inds, n_samples) + }else{ + # we would remove all but two features, whichs breaks the removal algorithm + non_var_inds = as.matrix(-1) + i_masks_for_permutations = masks_for_permutations + i_masked_bg_for_permutations = masked_bg_for_permutations + } + } else { + non_var_inds = as.matrix(-1) + i_masks_for_permutations = masks_for_permutations + i_masked_bg_for_permutations = masked_bg_for_permutations + } + + #apply masks and bg data for all permutations at once + X_test = apply_full_mask(x_instances[i], i_masks_for_permutations, i_masked_bg_for_permutations) + + #generate args for call to model + X_arg = append(list(X=X_test), model_args) + + #call model + P = eval(model_function, X_arg) + + #compute means, deviding n_rows by n_samples + P = compute_means_from_predictions(P=P, n_samples=n_samples) + + #compute phis + [phis, e] = compute_phis_from_prediction_means(P=P, permutations=permutations, non_var_inds=non_var_inds, n_partition_features=n_partition_features) + expected_m[i] = e + + #compute phis for this row from all permutations + row_phis[i] = t(phis) + } + #compute expected of model from all rows + expected = mean(expected_m) +} + +# Computes which indices do not vary from the background. +# Uses the appraoch from numpy.isclose() and compares to the largest diff of each feature in the bg data. +# In the futere, more advanced techniques like using std-dev of bg data as a tollerance could be used. +# +# INPUT: +# ----------------------------------------------------------------------------- +# x One single instance. +# X_bg Background dataset. +# ----------------------------------------------------------------------------- +# OUTPUT: +# ----------------------------------------------------------------------------- +# non_varying_inds A row-vector with all the indices that do not vary from the background dataset. +# ----------------------------------------------------------------------------- +get_non_varying_inds = function(Matrix[Double] x, Matrix[Double] X_bg) +return (Matrix[Double] non_varying_inds){ + #from numpy.isclose but adapted to fit MSE of shap, which is within the same scale + rtol = 1e-04 + atol = 1e-05 + + # compute distance metrics + diff = colMaxs(abs(X_bg -x)) + rdist = atol + rtol * colMaxs(abs(X_bg)) + + non_varying_inds = (diff <= rdist) + # translate to indices + non_varying_inds = t(seq(1,ncol(x))) * non_varying_inds + # remove the ones that do vary + non_varying_inds = removeEmpty(target=non_varying_inds, margin="cols") +} + +# Prepares a boolean mask for removing features according to permutaion. +# The resulting matrix needs to be inflated to a sample set by using prepare_samples_from_mask() before calling the model. +# +# INPUT: +# --------------------------------------------------------------------------------------- +# permutation A single permutation of varying features. +# If using partitions, remove them beforhand by using removePartitionsFromFeatures() from the utils. +# n_non_varying_inds The number of feature that do not vary in the background data. +# Can be retrieved e.g. by looking at std.dev +# partitions Matrix with first elemnt of partition in first row and last element of partition in second row. +# Used to treat partitions as one feature when creating masks. Useful for one-hot-encoded features. +# --------------------------------------------------------------------------------------- +# +# OUTPUT: +# ----------------------------------------------------------------------------- +# mask Boolean mask. +# ----------------------------------------------------------------------------- +prepare_mask_for_permutation = function(Matrix[Double] permutation, Integer n_non_varying_inds=0, + Matrix[Double] partitions=as.matrix(-1)) +return (Matrix[Double] masks){ + if(sum(partitions)!=-1){ + #can't use n_non_varying_inds and partitions at the same time + if(n_non_varying_inds > 0){ + stop("shap-explainer::prepare_mask_for_permutation:ERROR: Can't use n_non_varying_inds and partitions at the same time.") + } + #number of features not in permutation is diff between start and end of partitions, since first feature remains in permutation + skip_inds = partitions[2,] - partitions[1,] + + #skip these inds by treating them as non varying + n_non_varying_inds = sum(skip_inds) + } + + #total number of features + perm_len = ncol(permutation)+n_non_varying_inds + if(n_non_varying_inds > 0){ + #prep full constructor with placeholders + mask_constructor = matrix(perm_len+1, rows=1, cols = perm_len) + mask_constructor[1,1:ncol(permutation)] = permutation + }else{ + mask_constructor=permutation + } + + perm_cols = ncol(mask_constructor) + + # we compute mask on reverse permutation wnd reverse it later to get desired shape + + # create row indicator vector ctable + perm_mask_rows = seq(1,perm_cols) + #TODO: col-vector and matrix mult? + perm_mask_rows = matrix(1, rows=perm_cols, cols=perm_cols) * perm_mask_rows + perm_mask_rows = lower.tri(target=perm_mask_rows, diag=TRUE, values=TRUE) + perm_mask_rows = removeEmpty(target=matrix(perm_mask_rows, rows=1, cols=length(perm_mask_rows)), margin="cols") + + # create column indicator for ctable + rev_permutation = t(rev(t(mask_constructor))) + #TODO: col-vector and matrix mult? + perm_mask_cols = matrix(1, rows=perm_cols, cols=perm_cols) * mask_constructor + perm_mask_cols = lower.tri(target=perm_mask_cols, diag=TRUE, values=TRUE) + perm_mask_cols = removeEmpty(target = matrix(perm_mask_cols, cols=length(perm_mask_cols), rows=1), margin="cols") + #ctable + masks = table(perm_mask_rows, perm_mask_cols, perm_len, perm_len) + if(n_non_varying_inds > 0){ + #truncate non varying rows + masks = masks[1:ncol(permutation)] + + #replicate mask from first feature of each partionton to entire partitions + if(sum(partitions)!=-1){ + for ( i in 1:ncol(partitions) ){ + p_start = as.scalar(partitions[1,i]) + p_end = as.scalar(partitions[2,i]) + proxy = masks[,p_start] %*% matrix(1, rows=1, cols=p_end-p_start) + masks[,p_start+1:p_end] = proxy + } + } + } + + # add inverted mask and revert order for desired shape for forward and backward pass + masks = rbind(!masks[nrow(masks)],masks, rev(!masks[1:nrow(masks)-1])) +} + +# Prepares the full mask for marginalization by repeating the rows +# +# INPUT: +# --------------------------------------------------------------------------------------- +# mask Boolean mask with 1, where from x, and 0, where integrated over background data. +# n_samples Number samples for which to replicate. +# --------------------------------------------------------------------------------------- +# +# OUTPUT: +# ----------------------------------------------------------------------------- +# x_mask_full A replicated mask. +# ----------------------------------------------------------------------------- +prepare_full_mask = function(Matrix[Double] mask, Integer n_samples) + return (Matrix[Double] x_mask_full){ + x_mask_full = u_repeatRows(mask,n_samples) +} + +# Prepares the masked background by replicating the samples and masking them using the full mask. +# +# INPUT: +# --------------------------------------------------------------------------------------- +# x_mask_full Boolean mask replicated orw-wise. +# X_bg_samples Samples from background. Either the same n samples for all permutaions or +# n*p samples, so each permutation has its own samples. +# n_perms_in_samples Number of sample sets to identify block which need to be replicated in X_bg_samples. +# --------------------------------------------------------------------------------------- +# +# OUTPUT: +# ----------------------------------------------------------------------------- +# x_mask_full A replicated mask. +# ----------------------------------------------------------------------------- +prepare_masked_X_bg = function(Matrix[Double] x_mask_full, Matrix[Double] X_bg_samples, Integer n_perms_in_samples) +return (Matrix[Double] masked_X_bg){ + #Repeat background once for every row in original mask. + #If the same samples are used for each permutation, simply repeat the entire samples accordingly + if (n_perms_in_samples <= 1){ + #Since x_mask_full was already replicated row-wise by the number of rows in X_bg_samples, we devide by it. + masked_X_bg = u_repeatMatrix(X_bg_samples, nrow(x_mask_full)/nrow(X_bg_samples)) + }else{ + # if X_bg_samples has independent samples for each perm, it holds n_samples*n_perms rows. + block_size = nrow(X_bg_samples)/n_perms_in_samples + masked_X_bg = u_repeatMatrixBlocks(X_bg_samples, block_size, nrow(x_mask_full)/block_size/n_perms_in_samples) + } + + masked_X_bg = masked_X_bg * !x_mask_full +} + +# Applies the masked background and boolen mask to individual instance of interest. +# +# INPUT: +# --------------------------------------------------------------------------------------- +# x_row Instance of interest as row-vector. +# x_mask_full Boolean mask replicated orw-wise. +# masked_X_bg Prepared background samples. +# --------------------------------------------------------------------------------------- +# +# OUTPUT: +# ----------------------------------------------------------------------------- +# X_masked Set of synthesized instances for x_row. +# ----------------------------------------------------------------------------- +apply_full_mask = function(Matrix[Double] x_row, Matrix[Double] x_mask_full, Matrix[Double] masked_X_bg) +return (Matrix[Double] X_masked){ + #add the masked data from this row + X_masked = masked_X_bg + (x_mask_full * x_row) +} + +# Removes all rows from the prepared masks and background data whenever their feature is marked as non-varying. +# +# INPUT: +# --------------------------------------------------------------------------------------- +# masks Prepared and replicated mask for a singel instance. +# masked_X_bg Prepared and replicated background data. +# full_permutations The permutations from which the masks and bd data were created. +# non_var_inds A row-vector containiing the indices that were found to be not varying for this instance. +# n_samples The number samples over which each row is integarted. +# --------------------------------------------------------------------------------------- +# +# OUTPUT: +# ----------------------------------------------------------------------------- +# sub_mask A subset of masks where for each permutation the rows that correspond to +# non-varying features are removed. +# sub_masked_X_bg A subset of the background data where for each permutation the rows that correspond to +# non-varying features are removed. +# ----------------------------------------------------------------------------- +remove_inds = function(Matrix[Double] masks, Matrix[Double] masked_X_bg, Matrix[Double] full_permutations, + Matrix[Double] non_var_inds, Integer n_samples) +return(Matrix[Double] sub_mask, Matrix[Double] sub_masked_X_bg){ + offsets = seq(0,length(full_permutations)-ncol(full_permutations), ncol(full_permutations)) + + ### + # get row indices from permutations + total_row_index = full_permutations + offsets + total_row_index = matrix(total_row_index, rows=length(total_row_index), cols=1) + + row_index = toOneHot(total_row_index, nrow(total_row_index)) + #### + # get indices for all permutations as boolean mask + # repeat inds for every permutation + non_var_inds = matrix(1, rows=nrow(full_permutations), cols=ncol(non_var_inds)) * non_var_inds + #add offset + non_var_total = non_var_inds + offsets + #reshape into col-vec + non_var_total = matrix(non_var_total,rows=length(non_var_total), cols=1, byrow=FALSE) + non_var_mask = toOneHot(non_var_total, nrow(total_row_index)) + + non_var_mask = colSums(non_var_mask) + + ### + # multiply to get mask + non_var_rows = row_index %*% t(non_var_mask) + + #### + # unfold to full mask length + # reshape to add for each permutations + reshaped_rows = matrix(non_var_rows, rows=ncol(full_permutations), cols=nrow(full_permutations), byrow=FALSE) + + reshaped_rows_full = matrix(0,rows=1,cols=ncol(reshaped_rows)) + + #rbind to manipulate all perms at once + if( sum(reshaped_rows[nrow(reshaped_rows)]) > 0 ){ + #fix last row issue by setting last zero to one, if 1 in last row + row_indicator = (!reshaped_rows) * seq(1, nrow(reshaped_rows), 1) + row_indicator = colMaxs(row_indicator) + row_indicator = t(toOneHot(t(row_indicator), nrow(reshaped_rows))) + reshaped_rows_2 = reshaped_rows[1:nrow(reshaped_rows)-1] + row_indicator[1:nrow(reshaped_rows)-1] + reshaped_rows_full = rbind(reshaped_rows_full,reshaped_rows,reshaped_rows_2) + }else{ + reshaped_rows_full = rbind(reshaped_rows_full,reshaped_rows,reshaped_rows[1:nrow(reshaped_rows)-1]) + } + #reshape into col-vec + non_var_total = matrix(reshaped_rows_full, rows=length(reshaped_rows_full), cols=1, byrow=FALSE) + + #replicate, if masks already replicated + if (n_samples > 1){ + non_var_total = matrix(1, rows=nrow(non_var_total), cols=n_samples) * non_var_total + non_var_total = matrix(non_var_total, rows=length(non_var_total), cols=1) + } + + #remove from mask according to this vector + sub_mask = removeEmpty(target=masks, select=!non_var_total, margin="rows") + #set to 1 where non varying + #sub_mask = removed_short_mask | non_var_mask[1, 1:ncol(removed_short_mask)] + sub_masked_X_bg = removeEmpty(target=masked_X_bg, select=!non_var_total, margin="rows") +} + +# Performs the integration/marginalization by computing means. +# +# INPUT: +# --------------------------------------------------------------------------------------- +# P Predictions from model. +# n_samples Number of samples over which to take the mean. +# --------------------------------------------------------------------------------------- +# +# OUTPUT: +# ----------------------------------------------------------------------------- +# P_means The means of the sample groups. Each row is one group with means in cols. +# ----------------------------------------------------------------------------- +compute_means_from_predictions = function(Matrix[Double] P, Integer n_samples) + return (Matrix[Double] P_means){ + n_features = nrow(P)/n_samples + + #transpose and reshape to concat all values of same type + # TODO: unneccessary for vectors, only t() would be needed + P = matrix(t(P), cols=1, rows=length(P)) + + #reshape, so all predictions from one batch are in one row + P = matrix(P, cols=n_samples, rows=length(P)/n_samples) + + #compute row means + P_means = rowMeans(P) + + # reshape and transpose to get back to input dimensions + P_means = matrix(P_means, rows=n_features, cols=length(P_means)/n_features) +} + +# Computes phis from predictions for a permutation. +# +# INPUT: +# --------------------------------------------------------------------------------------- +# P Predictions for multiple permutations. +# permutations Permutations to get the feature indices from. +# non_var_inds Matrix holding the indices of non-varying features in the permutation that were ignored +# during prediction. These will be remove from the during computation of the phis. +# n_partition_features Number of features that are in partitions - number of partitions: +# There is still one feature per partition kept in the perms! +# --------------------------------------------------------------------------------------- +# +# OUTPUT: +# ----------------------------------------------------------------------------- +# phis Phis or shapley values computed from this permutation. +# Every row holds the phis for the corresponding feature. +# ----------------------------------------------------------------------------- +compute_phis_from_prediction_means = function(Matrix[Double] P, Matrix[Double] permutations, + Matrix[Double] non_var_inds=as.matrix(-1), Integer n_partition_features = 0) +return(Matrix[Double] phis, Double expected){ + perm_len=ncol(permutations) + n_non_var_inds = 0 + partial_permutations = permutations + + if(sum(non_var_inds)>0){ + n_non_var_inds = ncol(non_var_inds) + #flatten perms to remove from all perms at once + perms_flattened = matrix(permutations, rows=length(permutations), cols=1) + rem_selector = outer(perms_flattened, non_var_inds, "==") + rem_selector = rowSums(rem_selector) + partial_permutations = removeEmpty(target=perms_flattened, select=!rem_selector, margin="rows") + #reshape + partial_permutations = matrix(partial_permutations, rows=perm_len-n_non_var_inds, cols=nrow(permutations)) + perm_len = perm_len-n_non_var_inds + } + + #reshape P to get one col per permutation + P_perm = matrix(P, rows=2*perm_len, cols=nrow(permutations), byrow=FALSE) + + #forwards phis + forward_phis = P_perm[2:perm_len+1] - P_perm[1:perm_len] + + #backward phis and fix first and last + backward_phis = rbind(P_perm[perm_len+2] - P_perm[1], P_perm[perm_len+3:2*perm_len] - P_perm[perm_len+2:2*perm_len-1], P_perm[perm_len+1] - P_perm[2*perm_len]) + #reverse to match order of features in permutation + backward_phis = rev(backward_phis) + #avg forward and backward + forward_phis = matrix(forward_phis, rows=length(forward_phis), cols=1, byrow=FALSE) + backward_phis = matrix(backward_phis, rows=length(backward_phis), cols=1, byrow=FALSE) + avg_phis = (forward_phis + backward_phis) / 2 + + #aggregate to get only one phi per feature (and implicitly add zeros for non var inds) + perms_flattened = matrix(partial_permutations, rows=length(partial_permutations), cols=1) + phis = aggregate(target=avg_phis, groups=perms_flattened, fn="mean", ngroups=ncol(permutations)+n_partition_features) + + #get expected from first row + expected=mean(P_perm[1]) +} + +# Removes features that are part of a partition. +# Keeps first feature of partition as proxy for partition. +# +# INPUT: +# --------------------------------------------------------------------------------------- +# features Matrix holding features in its cols. +# partitions Matirx holding start and end of partitions in the cols of the first and second row respectively. +# --------------------------------------------------------------------------------------- +# +# OUTPUT: +# ----------------------------------------------------------------------------- +# short_features Matrix like fatures, but with the ones from partitiones removed. +# ----------------------------------------------------------------------------- +removePartitionsFromFeatures = function(Matrix[Double] features, Matrix[Double] partitions) +return (Matrix[Double] short_features){ + #remove from features + rm_mask = matrix(0, rows=1, cols=ncol(features)) + for (i in 1:ncol(partitions)){ + part_start = as.scalar(partitions[1,i]) + part_end = as.scalar(partitions[2,i]) + #include part_start as proxy of partition + rm_mask = rm_mask + (features > part_start) * (features <= part_end) + } + short_features = removeEmpty(target=features, margin="cols", select=!rm_mask) +} + +######################## +# Utility Functions that might be worth refactoring into its own file +# They could be used in other scenarios as well +######################## + + +# Samples from the background data X_bg. +# The function first uses all background samples without replacement, but if more samples are requested than +# available in X_bg, it shuffles X_bg and pulls more samples from it, making it sampling with replacement. +# TODO: Might be replacable by other builtin for sampling in the future +# +# INPUT: +# --------------------------------------------------------------------------------------- +# X_bg Matrix of background data +# samples Number of total samples +# always_shuffle Boolean to enable reshuffleing of X_bg, defaults to false. +# seed A seed for the shuffleing etc. +# --------------------------------------------------------------------------------------- +# +# OUTPUT: +# ----------------------------------------------------------------------------- +# X_sample New Matrix containing #samples, from X_bg, potentially with replacement. +# ----------------------------------------------------------------------------- +u_sample_with_potential_replace = function(Matrix[Double] X_bg, Integer samples, Boolean always_shuffle = 0, Integer seed) +return (Matrix[Double] X_sample){ + number_of_bg_samples = nrow(X_bg) + + # expect to not use all from background and subsample from it + num_of_full_X_bg = 0 + num_of_remainder_samples = samples + + # shuffle background if desired + if(always_shuffle) { + X_bg = u_shuffle(X_bg) + } + + # list to store references to generated matrices so we can rbind them in one call + samples_list = list() + + # in case we need more than in the background data, use it multiple times with replacement + if(samples >= number_of_bg_samples) { + u_printShapMessage("WARN: More samples ("+toString(samples)+") are requested than available in the background dataset ("+toString(number_of_bg_samples)+"). Using replacement", 1) + + # get number of full sets of background by integer division + num_of_full_X_bg = samples %/% number_of_bg_samples + # get remaining samples using modulo + num_of_remainder_samples = samples %% number_of_bg_samples + + #use background data once + samples_list = append(samples_list, X_bg) + + if(num_of_full_X_bg > 1){ + # add shuffled versions of background data + for (i in 1:num_of_full_X_bg-1){ + samples_list = append(samples_list, u_shuffle(X_bg)) + } + } + } + + # sample from background dataset for remaining samples + if (num_of_remainder_samples > 0){ + # pick remaining samples + random_samples_indices = sample(number_of_bg_samples, num_of_remainder_samples, seed) + + #contingency table to pick rows by multiplication + R_cont = table(random_samples_indices, random_samples_indices, number_of_bg_samples, number_of_bg_samples) + + #pick samples by multiplication with contingency table of indices and removing empty rows + samples_list = append(samples_list, removeEmpty(target=t(t(X_bg) %*% R_cont), margin="rows")) + } + + + if ( length(samples_list) == 1){ + #dont copy if only one matrix is in list, since this is a heavy hitter + X_sample = as.matrix(samples_list[1]) + } else { + #single call to bind all generated samples into one large matrix + X_sample = rbind(samples_list) + } +} + +# Simple utility function to shuffle (from shuffle.dml, but without storing to file). Shuffles rows. +# +# INPUT: +# --------------------------------------------------------------------------------------- +# X Matrix to be shuffled +# --------------------------------------------------------------------------------------- +# +# OUTPUT: +# ----------------------------------------------------------------------------- +# X_shuffled Matrix like X but ... shuffled... +# ----------------------------------------------------------------------------- +u_shuffle = function(Matrix[Double] X) +return (Matrix[Double] X_shuffled){ + num_col = ncol(X) + # Random vector used to shuffle the dataset + y = rand(rows=nrow(X), cols=1, min=0, max=1, pdf="uniform") + X = order(target = cbind(X, y), by = num_col + 1) + X_shuffled = X[,1:num_col] +} + +# Simple utility function to create a range of integers from start to end. +# +# INPUT: +# --------------------------------------------------------------------------------------- +# start First integer of range. +# stop First integer of range. +# --------------------------------------------------------------------------------------- +# +# OUTPUT: +# ----------------------------------------------------------------------------- +# range Matrix with range from start to end in its cols. +# ----------------------------------------------------------------------------- +u_range = function(Integer start, Integer end) +return (Matrix[Double] range){ + range = t(cumsum(matrix(1, rows=end-start+1, cols=1))) + range = range+start-1 +} + +# Replicates rows of the input matrix n-times. +# +# Example: +# [1,2] +# [3,4] +# becomes +# [1,2] +# [1,2] +# [3,4] +# [3,4] +# +# INPUT: +# ----------------------------------------------------------------------------- +# M Matrix where rows will be replicated. +# n_times Number of replications. +# ----------------------------------------------------------------------------- +# +# OUTPUT: +# ----------------------------------------------------------------------------- +# M Matrix of replicated rows. +# ----------------------------------------------------------------------------- +u_repeatRows = function(Matrix[Double] M, Integer n_times) +return(Matrix[Double] M){ + #get indices for new rows (e.g. 1,1,1,2,2,2 for 2 rows, each replicated 3 times) + indices = ceil(seq(1,nrow(M)*n_times,1) / n_times) + + #to one hot, so we get a replication matrix R + R = toOneHot(indices, nrow(M)) + + #matrix-mulitply to repeat rows + M = R %*% M +} + +# Replicates matrix n-times block-wise. +# +# Example: +# [1,2] +# [3,4] +# becomes +# [1,2] +# [3,4] +# [1,2] +# [3,4] +# +# INPUT: +# ----------------------------------------------------------------------------- +# M Matrix where rows will be replicated. +# n_times Number of replications. +# ----------------------------------------------------------------------------- +# +# OUTPUT: +# ----------------------------------------------------------------------------- +# M Matrix of replicated rows. +# ----------------------------------------------------------------------------- +u_repeatMatrix = function(Matrix[Double] M, Integer n_times) +return(Matrix[Double] M){ + n_rows=nrow(M) + n_cols=ncol(M) + #reshape to row vector + M = matrix(M, rows=1, cols=length(M)) + #broadcast + M = matrix(1, rows=n_times, cols=1) * M + #reshape to get matrix + M = matrix(M, rows=n_rows*n_times, cols=n_cols) +} + +# Like repeatMatrix(), but alows to define parts of matrix as blocks to replicate n-rows as a block. +u_repeatMatrixBlocks = function(Matrix[Double] M, Integer rows_per_block, Integer n_times) +return(Matrix[Double] M){ + n_rows=nrow(M) + n_cols=ncol(M) + #reshape to row vector + M = matrix(M, rows=n_rows/rows_per_block, cols=n_cols*rows_per_block) + #repeat block rows + M = u_repeatRows(M, n_times) + #reshape to get matrix + M = matrix(M, rows=n_rows*n_times, cols=n_cols) +} + +#utility function to print with shap-explainer-tag +u_printShapMessage = function(String message, Boolean verbose){ + if(verbose){ + print("shap-explainer::"+message) + } +} + diff --git a/src/main/java/org/apache/sysds/common/Builtins.java b/src/main/java/org/apache/sysds/common/Builtins.java index 98f92ae55ed..ca31cd331ad 100644 --- a/src/main/java/org/apache/sysds/common/Builtins.java +++ b/src/main/java/org/apache/sysds/common/Builtins.java @@ -300,6 +300,7 @@ public enum Builtins { SELVARTHRESH("selectByVarThresh", true), SEQ("seq", false), SYMMETRICDIFFERENCE("symmetricDifference", true), + SHAPEXPLAINER("shapExplainer", true), SHERLOCK("sherlock", true), SHERLOCKPREDICT("sherlockPredict", true), SHORTESTPATH("shortestPath", true), diff --git a/src/test/java/org/apache/sysds/test/functions/builtin/part2/BuiltinShapExplainerTest.java b/src/test/java/org/apache/sysds/test/functions/builtin/part2/BuiltinShapExplainerTest.java new file mode 100644 index 00000000000..0c207d66f39 --- /dev/null +++ b/src/test/java/org/apache/sysds/test/functions/builtin/part2/BuiltinShapExplainerTest.java @@ -0,0 +1,156 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +package org.apache.sysds.test.functions.builtin.part2; + + +import org.junit.Test; + +import java.util.HashMap; + +import org.apache.sysds.common.Types.ExecMode; +import org.apache.sysds.runtime.matrix.data.MatrixValue.CellIndex; +import org.apache.sysds.test.AutomatedTestBase; +import org.apache.sysds.test.TestConfiguration; +import org.apache.sysds.test.TestUtils; + +public class BuiltinShapExplainerTest extends AutomatedTestBase +{ + private static final String TEST_NAME = "shapExplainer"; + private static final String TEST_DIR = "functions/builtin/"; + private static final String TEST_CLASS_DIR = TEST_DIR + BuiltinShapExplainerTest.class.getSimpleName() + "/"; + + //FIXME need for padding result with zero + + @Override + public void setUp() { + addTestConfiguration(TEST_NAME, new TestConfiguration(TEST_CLASS_DIR, TEST_NAME, new String[]{"R"})); + } + + @Test + public void testPrepareMaskForPermutation() { + runShapExplainerUnitTest("prepare_mask_for_permutation"); + } + + @Test + public void testPrepareMaskForPartialPermutation() { + runShapExplainerUnitTest("prepare_mask_for_partial_permutation"); + } + + @Test + public void testPrepareMaskForPartitionedPermutation() { + runShapExplainerUnitTest("prepare_mask_for_partitioned_permutation"); + } + + @Test + public void testComputeMeansFromPredictions() { + runShapExplainerUnitTest("compute_means_from_predictions"); + } + + @Test + public void testComputePhisFromPredictionMeans() { + runShapExplainerUnitTest("compute_phis_from_prediction_means"); + } + + @Test + public void testComputePhisFromPredictionMeansNonVars() { + runShapExplainerUnitTest("compute_phis_from_prediction_means_non_vars"); + } + + @Test + public void testPrepareFullMask() { + runShapExplainerUnitTest("prepare_full_mask"); + } + + @Test + public void testPrepareMaskedXBg() { + runShapExplainerUnitTest("prepare_masked_X_bg"); + } + + @Test + public void testPrepareMaskedXBgIndependentPerms() { + runShapExplainerUnitTest("prepare_masked_X_bg_independent_perms"); + } + + @Test + public void testApplyFullMask() { + runShapExplainerUnitTest("apply_full_mask"); + } + + private void runShapExplainerUnitTest(String testType) { + ExecMode platformOld = setExecMode(ExecMode.SINGLE_NODE); + loadTestConfiguration(getTestConfiguration(TEST_NAME)); + String HOME = SCRIPT_DIR + TEST_DIR; + + try { + loadTestConfiguration(getTestConfiguration(TEST_NAME)); + + //execute given unit test + fullDMLScriptName = HOME + TEST_NAME + "Unit.dml"; + programArgs = new String[]{"-args", testType, output("R"), output("R_expected")}; + runTest(true, false, null, -1); + + //compare to expected result + HashMap result = readDMLMatrixFromOutputDir("R"); + HashMap result_expected = readDMLMatrixFromOutputDir("R_expected"); + + TestUtils.compareMatrices(result, result_expected, 1e-3, testType+"_result", testType+"_expected"); + + } + finally { + rtplatform = platformOld; + } + } + + @Test + public void testShapExplainerDummyData(){ + runShapExplainerComponentTest(false); + } + //TODO add test with real data + + private void runShapExplainerComponentTest(Boolean useRealData) { + ExecMode platformOld = setExecMode(ExecMode.SINGLE_NODE); + loadTestConfiguration(getTestConfiguration(TEST_NAME)); + String HOME = SCRIPT_DIR + TEST_DIR; + + try { + loadTestConfiguration(getTestConfiguration(TEST_NAME)); + + //execute given unit test + fullDMLScriptName = HOME + TEST_NAME + "Component.dml"; + programArgs = new String[]{"-args", output("R"), output("R_expected")}; + runTest(true, false, null, -1); + + //compare to expected phis + HashMap result = readDMLMatrixFromOutputDir("R_phis"); + HashMap result_expected = readDMLMatrixFromOutputDir("R_expected_phis"); + + TestUtils.compareMatrices(result, result_expected, 1e-3, "explainer_result_phis", "explainer_expected_phis"); + + //compare to expected value of model + HashMap result_e = readDMLMatrixFromOutputDir("R_e"); + HashMap result_expected_e = readDMLMatrixFromOutputDir("R_expected_e"); + + TestUtils.compareMatrices(result_e, result_expected_e, 1e-3, "explainer_result_e", "explainer_expected_e"); + } + finally { + rtplatform = platformOld; + } + } +} diff --git a/src/test/scripts/functions/builtin/shapExplainerComponent.dml b/src/test/scripts/functions/builtin/shapExplainerComponent.dml new file mode 100644 index 00000000000..8bf444b665d --- /dev/null +++ b/src/test/scripts/functions/builtin/shapExplainerComponent.dml @@ -0,0 +1,57 @@ +#------------------------------------------------------------- +# +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +# +#------------------------------------------------------------- + +######################################################################################################## +# THIS TEST IS HIGHLY DEPENDANT ON THE SAMPING! +# Changes in the dataset or number of samples etc. migh already be enough to change the expected result. +######################################################################################################## + +model_args = list(mult=1) +x_instances = matrix("100 200 300 100 300 400 100 100 500", rows=3, cols=3) +X_bg = matrix("11 12 13 21 22 23 31 32 33 41 42 43", rows=4, cols=3) +n_permutations = 2 +n_samples = 3 +seed = 42 + +#model for explainer test +dummyModel = function(Matrix[Double] X, Double mult) + return(Matrix[Double] P){ + P = rowSums(X)*mult +} + +[result_phis, result_e] = shapExplainer("dummyModel", model_args, x_instances, X_bg, n_permutations, n_samples, 0, as.matrix(-1), seed, 1) +result_e = cbind(as.matrix(result_e), as.matrix(0)) +#TODO for some reason storing just the scalar results in errors, so we create a small matrix by padding with a zero. +# Might be due to comma vs dot separation of decimals in strings if systems uses german local or other. + +expected_result_phis = matrix("69 168 267 69 268 367 69 68 467", rows=3, cols=3) +expected_result_e = matrix("96 0", rows=1, cols=2) + +path_phis=$1+"_phis" +path_e=$1+"_e" +path_expected_phis=$2+"_phis" +path_expected_e=$2+"_e" + +write(result_phis, path_phis) +write(result_e, path_e) +write(expected_result_phis, path_expected_phis) +write(expected_result_e, path_expected_e) + diff --git a/src/test/scripts/functions/builtin/shapExplainerUnit.dml b/src/test/scripts/functions/builtin/shapExplainerUnit.dml new file mode 100644 index 00000000000..c5f227a67e5 --- /dev/null +++ b/src/test/scripts/functions/builtin/shapExplainerUnit.dml @@ -0,0 +1,104 @@ +#------------------------------------------------------------- +# +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +# +#------------------------------------------------------------- +source("scripts/builtin/shapExplainer.dml") as shap; + +if ($1 == 'prepare_mask_for_permutation') { + #prepare_mask_for_permutation + perm = matrix("3 1 2", cols=3, rows=1) + result = shap::prepare_mask_for_permutation(permutation=perm) + expected_result = matrix("0 0 0 0 0 1 1 0 1 1 1 1 0 1 0 1 1 0", rows=6, cols=3) + +} else if ($1 == 'prepare_mask_for_partial_permutation') { + #prepare_mask_for_partial_permutation + perm = matrix("4 1 2", cols=3, rows=1) + result = shap::prepare_mask_for_permutation(permutation=perm, n_non_varying_inds=2) + expected_result = matrix("0 0 1 0 1 0 0 0 1 0 1 0 0 1 0 1 1 0 1 0 0 1 1 0 1 1 1 1 0 1", rows=6, cols=5) + +} else if ($1 == 'prepare_mask_for_partitioned_permutation') { + #prepare_mask_for_partitioned_permutation + perm = matrix("4 1 2", cols=3, rows=1) + partitions = matrix("2 4 3 5", cols=2, rows=2) + result = shap::prepare_mask_for_permutation(permutation=perm, partitions=partitions) + expected_result = matrix("0 0 0 0 0 0 0 0 1 1 1 0 0 1 1 1 1 1 1 1 0 1 1 0 0 1 1 1 0 0", rows=6, cols=5) + +} else if ($1 == 'compute_means_from_predictions') { + #compute_means_from_predictions + p = matrix("2 3 3 4 4 5", rows=6, cols=1) + result = shap::compute_means_from_predictions(p, 2) + expected_result = matrix("2.5 3.5 4.5", rows=3, cols=1) + +} else if ($1 == 'compute_phis_from_prediction_means') { + #compute_phis_from_prediction_means + permutation = matrix("2 3 4 1 5", cols=5, rows=1) + P_perm = matrix("10 21 22 23 24 100 31 32 33 34", rows=10, cols=1) + result = shap::compute_phis_from_prediction_means(P=P_perm, permutations=permutation) + expected_result = matrix("1 38.5 1 1 48.5", rows=5, cols=1) + +} else if ($1 == 'compute_phis_from_prediction_means_non_vars') { + #compute_phis_from_prediction_means with non varying inds + permutation = matrix("3 4 2 1 5", cols=5, rows=1) + non_varying_inds= matrix("2", rows=1, cols=1) + P_perm = matrix("10 22 23 24 100 31 32 33", rows=8, cols=1) + result = shap::compute_phis_from_prediction_means(P=P_perm, permutations=permutation, non_var_inds=non_varying_inds) + expected_result = matrix("1 0 39.5 1 48.5", rows=5, cols=1) + +} else if ($1 == 'prepare_full_mask') { + #prepare_full_mask + mask = matrix("1 0 0 1", rows=2, cols=2) + result = shap::prepare_full_mask(mask, 3) + result = shap::u_repeatRows(mask,3) + expected_result = matrix("1 0 1 0 1 0 0 1 0 1 0 1", rows=6, cols=2) + +} else if ($1 == 'prepare_masked_X_bg') { + #prepare_masked_X_bg + mask = matrix("1 0 0 1", rows=2, cols=2) + full_mask = shap::prepare_full_mask(mask, 3) + X_bg_samples = matrix("11 12 21 22 31 32", rows=3, cols=2) + result = shap::prepare_masked_X_bg(full_mask, X_bg_samples, 0) + expected_result = matrix("0 12 0 22 0 32 11 0 21 0 31 0", rows=6, cols=2) + +} else if ($1 == 'prepare_masked_X_bg_independent_perms') { + #prepare_masked_X_bg for independent perms + mask = matrix("1 0 0 1 1 0 0 1", rows=4, cols=2) + full_mask = shap::prepare_full_mask(mask, 3) + X_bg_samples = matrix("11 12 21 22 31 32 41 42 51 52 61 62", rows=6, cols=2) + result = shap::prepare_masked_X_bg(full_mask, X_bg_samples, 2) + expected_result = matrix("0 12 0 22 0 32 11 0 21 0 31 0 0 42 0 52 0 62 41 0 51 0 61 0", rows=12, cols=2) + +} else if ($1 == 'apply_full_mask') { + #apply_full_mask + x_row = matrix("100 200", rows=1, cols=2) + mask = matrix("1 0 0 1", rows=2, cols=2) + full_mask = shap::prepare_full_mask(mask, 3) + X_bg_samples = matrix("11 12 21 22 31 32", rows=3, cols=2) + masked_X_bg = shap::prepare_masked_X_bg(full_mask, X_bg_samples, 0) + result = shap::apply_full_mask(x_row, full_mask, masked_X_bg) + expected_result = matrix("100 12 100 22 100 32 11 200 21 200 31 200", rows=6, cols=2) + +} else { + print("Test type "+$1+" unknown.") + result = matrix("100 100", rows=1, cols=2) + expected_result = matrix("0 0", rows=1, cols=2) +} + +write(result, $2) +write(expected_result, $3) +