Skip to content

Commit

Permalink
Merge pull request #2 from tiemvanderdeure/explain
Browse files Browse the repository at this point in the history
add shapley values
  • Loading branch information
tiemvanderdeure authored Dec 1, 2023
2 parents f325a76 + 1cd9062 commit 7e4315a
Show file tree
Hide file tree
Showing 4 changed files with 42 additions and 5 deletions.
1 change: 1 addition & 0 deletions Project.toml
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,7 @@ MLJDecisionTreeInterface = "c6f25543-311c-4c74-83dc-3ea6d1015661"
MLJGLMInterface = "caf8df21-4939-456d-ac9c-5fefbfb04c0c"
PrettyTables = "08abe8d2-0d0c-5749-adfa-8a2ac140af0d"
Rasters = "a3a2b9e3-a471-40c9-b274-f788e487c689"
Shapley = "855ca7ad-a6ef-4de2-9ca8-726fe2a39065"
StatsBase = "2913bbd2-ae8a-5f71-8c99-4fb6c76f3a91"
Tables = "bd369af6-aec1-5ad0-b16a-f7cc5008161c"

Expand Down
5 changes: 3 additions & 2 deletions src/SpeciesDistributionModels.jl
Original file line number Diff line number Diff line change
Expand Up @@ -2,17 +2,18 @@ module SpeciesDistributionModels

using MLJ

import GLM, Tables, StatsBase, PrettyTables, Rasters, EvoTrees, DecisionTree
import GLM, Tables, StatsBase, PrettyTables, Rasters, EvoTrees, DecisionTree, Shapley

using Rasters: Raster, RasterStack

import CategoricalArrays.CategoricalArray

export SDMensemble, predict, sdm, select, machines, machine_keys
export SDMensemble, predict, sdm, select, machines, machine_keys, shap

include("models.jl")
include("ensemble.jl")
include("predict.jl")
include("explain.jl")

end

Expand Down
6 changes: 3 additions & 3 deletions src/ensemble.jl
Original file line number Diff line number Diff line change
Expand Up @@ -53,7 +53,7 @@ function select(
end

function Base.show(io::IO, mime::MIME"text/plain", ensemble::SDMensemble)
println(io, "SDMensemble with $(length(ensemble)) models")
println(io, "SDMensemble with $(Base.length(ensemble)) models")

println(io, "Model performance:")

Expand Down Expand Up @@ -102,8 +102,8 @@ function sdm(

@assert Tables.istable(presences) && Tables.istable(absence)

n_presence = length(Tables.rows(presences)) ##
n_absence = length(Tables.rows(absence))
n_presence = Base.length(Tables.rows(presences)) ##
n_absence = Base.length(Tables.rows(absence))
n_total = n_presence + n_absence

# merge presence and absence data into one namedtuple of vectors
Expand Down
35 changes: 35 additions & 0 deletions src/explain.jl
Original file line number Diff line number Diff line change
@@ -0,0 +1,35 @@
"""
    SDMshapley

Result of a Shapley-value evaluation of an `SDMensemble`, as produced by `shap`.

The struct is parametric so that every field has a concrete type for a given instance;
this keeps field access type-stable. The positional constructor
`SDMshapley(values, importances, ensemble, summary)` is unchanged.

# Fields
- `values`: one `NamedTuple` per model, containing all Shapley values for that model.
- `importances`: one `NamedTuple` per model, containing the mean absolute Shapley value
  of each variable.
- `ensemble`: the `SDMensemble` that was evaluated.
- `summary`: feature importance averaged over all models; printed by `show`.
"""
struct SDMshapley{V<:Vector{<:NamedTuple},I<:Vector{<:NamedTuple},S}
    values::V # Contains all shap values for all models
    importances::I # Contains mean absolute shap for each variable for each model
    ensemble::SDMensemble
    summary::S
end

"""
    Base.size(shap::SDMshapley)

Return the size of `shap.values`, the collection of per-model Shapley values.
"""
function Base.size(shap::SDMshapley)
    return Base.size(shap.values)
end
"""
    Base.length(shap::SDMshapley)

Return the number of entries in `shap.values`, i.e. one per evaluated model.
"""
function Base.length(shap::SDMshapley)
    return Base.length(shap.values)
end

"""
    Base.show(io::IO, mime::MIME"text/plain", shap::SDMshapley)

Write a multi-line description of `shap` to `io`: a header line with the number of
models, a caption line, and then `shap.summary` rendered with the same `mime`.
"""
function Base.show(io::IO, mime::MIME"text/plain", shap::SDMshapley)
    header = "Shapley evaluation for SDM ensemble with $(Base.length(shap)) models"
    println(io, header)

    println(io, "Mean feature importance:")
    # Rendering of the summary itself is delegated to its own `show` method.
    return Base.show(io, mime, shap.summary)
end

"""
    shap(ensemble; parallelism = Shapley.CPUThreads(), n_samples = 50)

Compute Monte Carlo Shapley values for every entry of `ensemble.trained_models`, using
`ensemble.data.predictor` as the input data, and bundle the results in an `SDMshapley`.

# Keywords
- `parallelism`: passed as the first argument to `Shapley.MonteCarlo`.
- `n_samples`: passed as the second argument to `Shapley.MonteCarlo`.

# Returns
An `SDMshapley` holding the per-model Shapley values, the per-model mean absolute
Shapley value of each variable, the ensemble itself, and the importances averaged over
all models for each variable in `ensemble.predictors`.
"""
function shap(ensemble; parallelism = Shapley.CPUThreads(), n_samples = 50)
    # One Shapley decomposition per trained model. The function being explained is the
    # predicted probability of the `true` class, converted to Float64 because
    # some ml models return float32s - where to handle this?
    shapvalues = map(ensemble.trained_models) do model
        Shapley.shapley(
            Shapley.MonteCarlo(parallelism, n_samples),
            ensemble.data.predictor,
        ) do x
            Float64.(MLJ.pdf.(MLJ.predict(model.machine, x), true))
        end
    end

    # Mean absolute Shapley value of a single variable.
    meanabs(val) = mapreduce(abs, +, val) / Base.length(val)

    # Per-model importances: apply `meanabs` to each variable of each model's result.
    importances = map(shapvalues) do vals
        map(meanabs, vals)
    end

    # Ensemble-wide importance: average each variable's importance over the models.
    n_models = Base.length(importances)
    summary = NamedTuple(
        var => mapreduce(imp -> getfield(imp, var), +, importances) / n_models
        for var in ensemble.predictors
    )

    return SDMshapley(shapvalues, importances, ensemble, summary)
end



0 comments on commit 7e4315a

Please sign in to comment.