Skip to content

Commit

Permalink
add StatsBase extension and weigthed stats
Browse files Browse the repository at this point in the history
  • Loading branch information
rafaqz committed Nov 2, 2024
1 parent 667741d commit ee1e6ac
Show file tree
Hide file tree
Showing 4 changed files with 55 additions and 2 deletions.
6 changes: 5 additions & 1 deletion Project.toml
Original file line number Diff line number Diff line change
Expand Up @@ -27,11 +27,13 @@ Tables = "bd369af6-aec1-5ad0-b16a-f7cc5008161c"
AlgebraOfGraphics = "cbdf2221-f076-402e-a563-3d30da359d67"
CategoricalArrays = "324d7699-5711-5eae-9e2f-1d82baa6b597"
Makie = "ee78f7c6-11fb-53f2-987a-cfe4a2b5a57a"
StatsBase = "2913bbd2-ae8a-5f71-8c99-4fb6c76f3a91"

[extensions]
DimensionalDataAlgebraOfGraphicsExt = "AlgebraOfGraphics"
DimensionalDataCategoricalArraysExt = "CategoricalArrays"
DimensionalDataMakie = "Makie"
DimensionalDataStatsBase = "StatsBase"

[compat]
Adapt = "2, 3.0, 4"
Expand Down Expand Up @@ -69,6 +71,7 @@ RecipesBase = "0.7, 0.8, 1"
SafeTestsets = "0.1"
SparseArrays = "1"
Statistics = "1"
StatsBase = "0.34"
StatsPlots = "0.15"
TableTraits = "1"
Tables = "1"
Expand Down Expand Up @@ -98,9 +101,10 @@ OffsetArrays = "6fe1bfb0-de20-5000-8ca7-80f57d26f881"
Plots = "91a5bcdd-55d7-5caf-9e0b-520d859cae80"
Random = "9a3f8284-a2c9-5f02-9a11-845980a1fd5c"
SafeTestsets = "1bc83da4-3b8d-516f-aca4-4fe02f6d838f"
StatsBase = "2913bbd2-ae8a-5f71-8c99-4fb6c76f3a91"
StatsPlots = "f3b207a7-027a-5e70-b257-86293d7955fd"
Test = "8dfed614-e22c-5e08-85e1-65c5234f0b40"
Unitful = "1986cc42-f94f-5a68-af5c-568840ba703d"

[targets]
test = ["AlgebraOfGraphics", "Aqua", "ArrayInterface", "BenchmarkTools", "CategoricalArrays", "ColorTypes", "Combinatorics", "CoordinateTransformations", "DataFrames", "Distributions", "Documenter", "GPUArrays", "ImageFiltering", "ImageTransformations", "JLArrays", "CairoMakie", "OffsetArrays", "Plots", "Random", "SafeTestsets", "StatsPlots", "Test", "Unitful"]
test = ["AlgebraOfGraphics", "Aqua", "ArrayInterface", "BenchmarkTools", "CategoricalArrays", "ColorTypes", "Combinatorics", "CoordinateTransformations", "DataFrames", "Distributions", "Documenter", "GPUArrays", "ImageFiltering", "ImageTransformations", "JLArrays", "CairoMakie", "OffsetArrays", "Plots", "Random", "SafeTestsets", "StatsBase", "StatsPlots", "Test", "Unitful"]
24 changes: 24 additions & 0 deletions ext/DimensionalDataStatsBase.jl
Original file line number Diff line number Diff line change
@@ -0,0 +1,24 @@
module DimensionalDataStatsBase

using DimensionalData
using Statistics
using StatsBase

const DD = DimensionalData

function Statistics.mean(A::AbstractDimArray, w::StatsBase.AbstractWeights; dims=:)
data = mean(parent(A), w; dims=dimnum(A, dims))
return rebuild(A, data, DD.reducedims(A, dims))
end
# For ambiguity
function Statistics.mean(A::AbstractDimArray, w::StatsBase.UnitWeights; dims=:)
data = mean(parent(A), w; dims=dimnum(A, dims))
return rebuild(A, data, DD.reducedims(A, dims))
end

function Base.sum(A::AbstractDimArray, w::AbstractWeights{<:Real}; dims=:)
data = sum(parent(A), w; dims=dimnum(A, dims))
return rebuild(A, data, DD.reducedims(A, dims))
end

end
11 changes: 11 additions & 0 deletions src/Dimensions/dimunitrange.jl
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,17 @@ end

@inline Base.parent(r::DimUnitRange) = r.range

function Base.reduced_index(dur::DimUnitRange)
r = Base.reduced_index(parent(dur))
d = dims(dur)
d1 = if isreverse(d)
d[end:end]
else
d[begin:begin]
end
return DimUnitRange(r, d1)
end

@inline dims(r::DimUnitRange) = r.dim
@inline dims(rs::Tuple{DimUnitRange,Vararg{DimUnitRange}}) = map(dims, rs)

Expand Down
16 changes: 15 additions & 1 deletion test/ecosystem.jl
Original file line number Diff line number Diff line change
@@ -1,4 +1,12 @@
using OffsetArrays, ImageFiltering, ImageTransformations, ArrayInterface, DimensionalData, Test
using DimensionalData
using Test

using OffsetArrays
using ImageFiltering
using ImageTransformations
using ArrayInterface
using StatsBase

using DimensionalData.Lookups

@testset "ArrayInterface" begin
Expand Down Expand Up @@ -70,3 +78,9 @@ end
imresize(parent(dims(da, X)), ratio=2)
imrotate(da, 0.3)
end

@testset "StatsBase" begin
da = rand(X(1:10), Y(1:3))
@test mean(da, weights([0.3,0.3,0.4]); dims=Y) == mean(parent(da), weights([0.3,0.3,0.4]); dims=2)
@test sum(da, weights([0.3,0.3,0.4]); dims=Y) == sum(parent(da), weights([0.3,0.3,0.4]); dims=2)
end

0 comments on commit ee1e6ac

Please sign in to comment.