Commit
evaluate entire ensemble
tiemvanderdeure committed Jun 18, 2024
1 parent 7436746 commit a72b70e
Showing 1 changed file with 88 additions and 38 deletions.
126 changes: 88 additions & 38 deletions src/evaluate.jl
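In short, the commit makes evaluation cover the ensemble as a whole: per fold, the member models' probabilistic predictions are reduced (mean by default) and scored alongside the per-machine results, and the show method gains an "ensemble" row. A hypothetical usage sketch of what this enables; the public `evaluate` entry point and its call shape are assumptions, not shown in this diff:

# Hypothetical sketch; `evaluate` and `my_ensemble` are assumptions, not part of this diff.
ev = evaluate(my_ensemble)    # an SDMensembleEvaluation
ev.results                    # new: fold-wise scores of the reduced (mean) ensemble prediction
machine_evaluations(ev)       # per-machine scores, as before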
@@ -7,14 +7,11 @@ end
struct SDMgroupEvaluation <: AbstractVector{SDMmachineEvaluation}
    machine_evaluations::Vector{SDMmachineEvaluation}
    group::SDMgroup
end

struct SDMensembleEvaluation <: AbstractVector{SDMgroupEvaluation}
    group_evaluations::Vector{SDMgroupEvaluation}
    ensemble::SDMensemble
    measures
    results
end
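Both evaluation wrappers subtype AbstractVector, so a whole-ensemble evaluation nests by plain indexing; a brief orientation sketch (the variable name is illustrative):

# Orientation sketch (variable name illustrative): evaluations nest by indexing,
# matching the `ensembleeval[1][1]` accesses used further down.
ensemble_eval          # SDMensembleEvaluation
ensemble_eval[1]       # SDMgroupEvaluation of the first group
ensemble_eval[1][2]    # SDMmachineEvaluation of that group's second machine (fold 2)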

@@ -32,6 +29,9 @@ Base.size(group::SDMgroupEvaluation) = Base.size(group.machine_evaluations)

evaluation_sets(macheval::SDMmachineEvaluation) = keys(macheval.results)
evaluation_sets(e::SDMgroupOrEnsembleEvaluation) = evaluation_sets(e[1])

measures(macheval::SDMmachineEvaluation) = macheval.measures
measures(e::SDMgroupOrEnsembleEvaluation) = measures(e[1])
"""
machine_evaluations(eval)
@@ -53,31 +53,31 @@ machine_evaluations
function machine_evaluations(groupeval::SDMgroupEvaluation)
    sets = evaluation_sets(groupeval)
    map(sets) do set
        map(keys(measures(groupeval))) do key
            map(groupeval) do e
                e.results[set][key].score
            end
        end |> NamedTuple{keys(measures(groupeval))}
    end |> NamedTuple{sets}
end
function machine_evaluations(ensembleeval::SDMensembleEvaluation)
    sets = keys(ensembleeval[1][1].results)
    map(sets) do set
        map(keys(measures(ensembleeval))) do key
            mapreduce(vcat, ensembleeval) do groupeval
                map(groupeval) do e
                    e.results[set][key].score
                end
            end
        end |> NamedTuple{keys(measures(ensembleeval))}
    end |> NamedTuple{sets}
end
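For orientation, the structure these methods return is keyed first by evaluation set, then by measure, holding one score per machine; a hypothetical access sketch (the `:auc` measure and the set names are assumptions that depend on how the evaluation was run):

using Statistics

# Hypothetical access sketch; `groupeval`, `:auc`, and the set names are assumptions.
scores = machine_evaluations(groupeval)
scores.test.auc          # vector of test AUC scores, one per machine/fold
mean(scores.test.auc)    # mean test AUC across folds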

## Show methods
function Base.show(io::IO, mime::MIME"text/plain", evaluation::SDMmachineEvaluation)
println(io, "SDMmachineEvaluation")

measures = collect(keys(evaluation.measures))
measures = collect(keys(measures(evaluation)))
sets = evaluation_sets(evaluation)
scores = map(sets) do s
round.(getfield.(collect(evaluation.results[s]), :score); digits = 2)
@@ -89,7 +89,7 @@ function Base.show(io::IO, mime::MIME"text/plain", evaluation::SDMmachineEvaluat
end

function Base.show(io::IO, mime::MIME"text/plain", evaluation::SDMgroupEvaluation)
measures = collect(keys(evaluation.measures))
measures = collect(keys(measures(evaluation)))
train_scores, test_scores = machine_evaluations(evaluation)
folds = getfield.(evaluation.group, :fold)

@@ -102,11 +102,13 @@ function Base.show(io::IO, mime::MIME"text/plain", evaluation::SDMgroupEvaluatio
end

function Base.show(io::IO, mime::MIME"text/plain", evaluation::SDMensembleEvaluation)
measures = collect(keys(evaluation.measures))
measures = collect(keys(evaluation[1][1].measures))#collect(keys(measures(evaluation)))
models = getfield.(evaluation.ensemble, :model_name)

# get scores from each group
scores = machine_evaluations.(evaluation)
machine_scores = machine_evaluations.(evaluation)
ensemble_scores = map(e -> map(e -> e.score, e), evaluation.results)
scores = vcat(machine_scores, ensemble_scores)
# get mean test and train from each group for each measure.
# then invert to a namedtuple where measures are keys
println(io, "$(typeof(evaluation)) with $(length(measures)) performance measures")
@@ -116,14 +118,14 @@ function Base.show(io::IO, mime::MIME"text/plain", evaluation::SDMensembleEvalua
        s = map(scores) do score
            map(Statistics.mean, getfield(score, k))
        end |> Tables.columntable
        PrettyTables.pretty_table(io, merge((; model = [models; "ensemble"]), s))
    end
end


## Core evaluator
# internal method to get a vector of scores from y_hats, ys, and a namedtuple of measures
function _evaluate(y_hat::MLJBase.UnivariateFiniteArray, y::AbstractArray{<:CategoricalArrays.CategoricalValue}, measures)
    kinds_of_proxy = map(StatisticalMeasuresBase.kind_of_proxy, measures)

    # if any are literal targets (threshold-dependent), compute the confusion matrices outside the loop
@@ -152,8 +154,10 @@ function _conf_mats_from_thresholds(scores, y, thresholds)
    levels = [false, true]
    # use the internal method to avoid constructing indexer every time
    indexer = StatisticalMeasures.LittleDict(levels[i] => i for i in eachindex(levels)) |> StatisticalMeasures.freeze
    # preallocate y_ once and overwrite it in-place for each threshold
    y_ = boolean_categorical(falses(size(scores)...))
    broadcast(thresholds) do t
        broadcast!(>=(t), y_, scores)
        StatisticalMeasures.ConfusionMatrices._confmat(y_, y, indexer, levels, true)
    end
end
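The preallocate-and-`broadcast!` trick above just recomputes hard class assignments per threshold; the same idea can be sketched without the internal `_confmat` API (all names and numbers below are illustrative):

# Self-contained sketch: one set of confusion counts per threshold.
scores = [0.1, 0.4, 0.6, 0.9]      # predicted probabilities of `true`
y = [false, false, true, true]     # observed labels
for t in (0.5, 0.7)
    y_hat = scores .>= t           # same comparison broadcast!(>=(t), y_, scores) performs in-place
    tp = sum(y_hat .& y); fp = sum(y_hat .& .!y)
    fn = sum(.!y_hat .& y); tn = sum(.!y_hat .& .!y)
    println("t = $t: tp=$tp fp=$fp fn=$fn tn=$tn")
end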
@@ -162,64 +166,110 @@ end
_ev_predict(sdm_machine, data::AbstractVector{<:Integer}) = MLJBase.predict(sdm_machine.machine; rows = data)
_ev_predict(sdm_machine, data) = MLJBase.predict(sdm_machine, data)
# Evaluate a single SDMmachine
# helpers to assemble the response data and the predictions for the requested sets
function _ev_ydata(sdm_machine, train, test, validation)
    machdata = data(sdm_machine)
    d_y = (;)
    if train
        d_y = merge(d_y, (;train = view(machdata.response, sdm_machine.train_rows)))
    end
    if test
        d_y = merge(d_y, (;test = view(machdata.response, sdm_machine.test_rows)))
    end
    if !isempty(validation)
        d_y = merge(d_y, (;validation = validation[2]))
    end
    return d_y
end

function _ev_predict(sdm_machine::SDMmachine, train, test, validation)
    machdata = data(sdm_machine)
    # set up namedtuple with data/rows, throw out if nothing/false
    d_X = (;)
    if train
        d_X = merge(d_X, (;train = Tables.subset(machdata.predictor, sdm_machine.train_rows)))
    end
    if test
        d_X = merge(d_X, (;test = Tables.subset(machdata.predictor, sdm_machine.test_rows)))
    end
    if !isempty(validation)
        d_X = merge(d_X, (;validation = validation[1]))
    end
    map(X -> MLJBase.predict(sdm_machine.machine, X), d_X)
end

function _evaluate(sdm_machine::SDMmachine, y_hats::NamedTuple, ys::NamedTuple, measures::NamedTuple)
    results = map(y_hats, ys) do y_hat, y
        _evaluate(y_hat, y, measures)
    end

    return SDMmachineEvaluation(sdm_machine, measures, results)
end

function _evaluate(sdm_machine::SDMmachine, measures::NamedTuple, train, test, validation)
    y_hat = _ev_predict(sdm_machine, train, test, validation)
    y = _ev_ydata(sdm_machine, train, test, validation)
    _evaluate(sdm_machine, y_hat, y, measures)
end

# Evaluate a group
function _evaluate(group::SDMgroup, measures, train::Bool, test::Bool, validation::Tuple)
    machine_evaluations = map(m -> (_evaluate(m, measures, train, test, validation)), group)
    return SDMgroupEvaluation(
        machine_evaluations,
        group
    )
end

function _evaluate(group::SDMgroup, yhat, y, measures::NamedTuple)
    machine_evaluations = map((m, yhat_, y_) -> _evaluate(m, yhat_, y_, measures), group, yhat, y)
    return SDMgroupEvaluation(
        machine_evaluations,
        group
    )
end

function _evaluate(ensemble::SDMensemble, measures, train, test, validation, reducer = Statistics.mean)
    # works only if all groups have the same number of folds, for now

    # get y_hats for every machine in every group, and the response data per fold
    y_hats = map(g -> map(m -> _ev_predict(m, train, test, validation), g), ensemble)
    ys = map(m -> _ev_ydata(m, train, test, validation), first(ensemble))
    # reduce (usually mean) across models, per fold, for each of train/test/validation
    y_hats_reduced = map(y_hats...) do y...
        map(y...) do y_...
            p = map(x -> pdf.(x, true), y_) # probabilities of `true` from the UnivariateFinite predictions
            ps_red = map((p...) -> reducer(p), p...)
            MLJBase.UnivariateFinite(boolean_categorical([false, true]), ps_red, augment = true)
        end
    end

    group_evaluations = map((g, y_hat) -> (_evaluate(g, y_hat, ys, measures)), ensemble, y_hats)

    # evaluate the reduced ensemble prediction for each fold
    ensemble_evaluation = [
        map((y_hat, y) -> _evaluate(y_hat, y, measures), y_hats_reduced[i], ys[i])
        for i in eachindex(y_hats_reduced)
    ]

    # invert the structure so it becomes a namedtuple (train/test) of namedtuples (measures) containing vectors (folds)
    sets = keys(ys[1])
    ensemble_evaluation = map(sets) do set
        map(keys(measures)) do key
            map((:score, :threshold)) do s
                map(ensemble_evaluation) do e
                    e[set][key][s]
                end
            end |> NamedTuple{(:score, :threshold)}
        end |> NamedTuple{keys(measures)}
    end |> NamedTuple{sets}

    return SDMensembleEvaluation(
        group_evaluations,
        ensemble,
        measures,
        ensemble_evaluation
    )
end
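The inversion at the end of `_evaluate(ensemble, ...)` is nested `map`s over keys rebuilt into NamedTuples; the same pattern in isolation, with toy scores for two folds (all names and numbers are illustrative):

# Toy version of the inversion: a Vector (folds) of set => measure results
# becomes set => measure => Vector (folds).
per_fold = [(; train = (; acc = 0.90), test = (; acc = 0.80)),
            (; train = (; acc = 0.85), test = (; acc = 0.75))]
sets = keys(per_fold[1])
inverted = map(sets) do set
    map(keys(per_fold[1][set])) do key
        [e[set][key] for e in per_fold]
    end |> NamedTuple{keys(per_fold[1][set])}
end |> NamedTuple{sets}

inverted.test.acc    # == [0.80, 0.75]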
