diff --git a/Project.toml b/Project.toml index 072550f..54a8c6a 100644 --- a/Project.toml +++ b/Project.toml @@ -14,6 +14,7 @@ MLJDecisionTreeInterface = "c6f25543-311c-4c74-83dc-3ea6d1015661" MLJGLMInterface = "caf8df21-4939-456d-ac9c-5fefbfb04c0c" PrettyTables = "08abe8d2-0d0c-5749-adfa-8a2ac140af0d" Rasters = "a3a2b9e3-a471-40c9-b274-f788e487c689" +Shapley = "855ca7ad-a6ef-4de2-9ca8-726fe2a39065" StatsBase = "2913bbd2-ae8a-5f71-8c99-4fb6c76f3a91" Tables = "bd369af6-aec1-5ad0-b16a-f7cc5008161c" diff --git a/src/SpeciesDistributionModels.jl b/src/SpeciesDistributionModels.jl index b3dc053..6bfecb5 100644 --- a/src/SpeciesDistributionModels.jl +++ b/src/SpeciesDistributionModels.jl @@ -2,17 +2,18 @@ module SpeciesDistributionModels using MLJ -import GLM, Tables, StatsBase, PrettyTables, Rasters, EvoTrees, DecisionTree +import GLM, Tables, StatsBase, PrettyTables, Rasters, EvoTrees, DecisionTree, Shapley using Rasters: Raster, RasterStack import CategoricalArrays.CategoricalArray -export SDMensemble, predict, sdm, select, machines, machine_keys +export SDMensemble, predict, sdm, select, machines, machine_keys, shap include("models.jl") include("ensemble.jl") include("predict.jl") +include("explain.jl") end diff --git a/src/ensemble.jl b/src/ensemble.jl index ff18d96..d70d047 100644 --- a/src/ensemble.jl +++ b/src/ensemble.jl @@ -53,7 +53,7 @@ function select( end function Base.show(io::IO, mime::MIME"text/plain", ensemble::SDMensemble) - println(io, "SDMensemble with $(length(ensemble)) models") + println(io, "SDMensemble with $(Base.length(ensemble)) models") println(io, "Model performance:") @@ -102,8 +102,8 @@ function sdm( @assert Tables.istable(presences) && Tables.istable(absence) - n_presence = length(Tables.rows(presences)) ## - n_absence = length(Tables.rows(absence)) + n_presence = Base.length(Tables.rows(presences)) ## + n_absence = Base.length(Tables.rows(absence)) n_total = n_presence + n_absence # merge presence and absence data into one namedtuple of 
vectors diff --git a/src/explain.jl b/src/explain.jl new file mode 100644 index 0000000..7ad6cab --- /dev/null +++ b/src/explain.jl @@ -0,0 +1,35 @@ +struct SDMshapley + values::Vector{<:NamedTuple} # Contains all shap values for all models + importances::Vector{<:NamedTuple} # Contains mean absolute shap for each variable for each model + ensemble::SDMensemble + summary +end + +Base.size(shap::SDMshapley) = Base.size(shap.values) +Base.length(shap::SDMshapley) = Base.length(shap.values) + +function Base.show(io::IO, mime::MIME"text/plain", shap::SDMshapley) + println(io, "Shapley evaluation for SDM ensemble with $(Base.length(shap)) models") + + println(io, "Mean feature importance:") + Base.show(io, mime, shap.summary) +end + +function shap(ensemble; parallelism = Shapley.CPUThreads(), n_samples = 50) + shapvalues = map(ensemble.trained_models) do model + Shapley.shapley( + x -> Float64.(MLJ.pdf.(MLJ.predict(model.machine, x), true)), # some ml models return float32s - where to handle this? + Shapley.MonteCarlo(parallelism, n_samples), + ensemble.data.predictor + ) + end + + importances = map(vals -> map(val -> mapreduce(abs, +, val) / Base.length(val), vals), shapvalues) + + summary = NamedTuple(var => mapreduce(x -> getfield(x, var), +, importances) / Base.length(importances) for var in ensemble.predictors) + + return SDMshapley(shapvalues, importances, ensemble, summary) +end + + +