Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Simplify and document convolve #1452

Merged
merged 1 commit into from
Dec 9, 2021
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion Project.toml
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
name = "Distributions"
uuid = "31c24e10-a181-5473-b8eb-7969acd0382f"
authors = ["JuliaStats"]
version = "0.25.34"
version = "0.25.35"

[deps]
ChainRulesCore = "d360d2e6-b24c-11e9-a2a3-2a2ae2dbcce4"
Expand Down
1 change: 1 addition & 0 deletions docs/make.jl
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@ makedocs(
"reshape.md",
"cholesky.md",
"mixture.md",
"convolution.md",
"fit.md",
"extends.md",
"density_interface.md",
Expand Down
11 changes: 11 additions & 0 deletions docs/src/convolution.md
Original file line number Diff line number Diff line change
@@ -0,0 +1,11 @@
# Convolutions

A [convolution of two probability distributions](https://en.wikipedia.org/wiki/List_of_convolutions_of_probability_distributions)
is the probability distribution of the sum of two independent random variables that are
distributed according to these distributions.

The convolution of two distributions can be constructed with [`convolve`](@ref).

```@docs
convolve
```
63 changes: 23 additions & 40 deletions src/convolution.jl
Original file line number Diff line number Diff line change
@@ -1,25 +1,27 @@
"""
convolve(d1::T, d2::T) where T<:Distribution -> Distribution

Convolve two distributions of the same type to yield the distribution corresponding to the
sum of independent random variables drawn from the underlying distributions.

The function is only defined in the cases where the convolution has a closed form as
defined here https://en.wikipedia.org/wiki/List_of_convolutions_of_probability_distributions

* `Bernoulli`
* `Binomial`
* `NegativeBinomial`
* `Geometric`
* `Poisson`
* `Normal`
* `Cauchy`
* `Chisq`
* `Exponential`
* `Gamma`
* `MultivariateNormal`
convolve(d1::Distribution, d2::Distribution)

Convolve two distributions and return the distribution corresponding to the sum of
independent random variables drawn from the underlying distributions.

Currently, the function is only defined in cases where the convolution has a closed form.
More precisely, the function is defined if `d1` and `d2` belong to the same family of
distributions and that family is one of
* [`Bernoulli`](@ref)
* [`Binomial`](@ref)
* [`NegativeBinomial`](@ref)
* [`Geometric`](@ref)
* [`Poisson`](@ref)
* [`Normal`](@ref)
* [`Cauchy`](@ref)
* [`Chisq`](@ref)
* [`Exponential`](@ref)
* [`Gamma`](@ref)
* [`MvNormal`](@ref)

External links: [List of convolutions of probability distributions on Wikipedia](https://en.wikipedia.org/wiki/List_of_convolutions_of_probability_distributions)
"""
function convolve end
convolve(::Distribution, ::Distribution)

# discrete univariate
function convolve(d1::Bernoulli, d2::Bernoulli)
Expand Down Expand Up @@ -61,30 +63,11 @@ function convolve(d1::Gamma, d2::Gamma)
end

# continuous multivariate
# The first two methods exist for performance reasons to avoid unnecessarily converting
# PDMats to a Matrix
function convolve(
d1::Union{IsoNormal, ZeroMeanIsoNormal, DiagNormal, ZeroMeanDiagNormal},
d2::Union{IsoNormal, ZeroMeanIsoNormal, DiagNormal, ZeroMeanDiagNormal},
)
_check_convolution_shape(d1, d2)
return MvNormal(d1.μ .+ d2.μ, d1.Σ + d2.Σ)
end

function convolve(
d1::Union{FullNormal, ZeroMeanFullNormal},
d2::Union{FullNormal, ZeroMeanFullNormal},
)
_check_convolution_shape(d1, d2)
return MvNormal(d1.μ .+ d2.μ, d1.Σ.mat + d2.Σ.mat)
end

function convolve(d1::MvNormal, d2::MvNormal)
_check_convolution_shape(d1, d2)
return MvNormal(d1.μ .+ d2.μ, Matrix(d1.Σ) + Matrix(d2.Σ))
return MvNormal(d1.μ + d2.μ, d1.Σ + d2.Σ)
end


function _check_convolution_args(p1, p2)
p1 ≈ p2 || throw(ArgumentError(
"$(p1) !≈ $(p2): distribution parameters must be approximately equal",
Expand Down
122 changes: 32 additions & 90 deletions test/convolution.jl
Original file line number Diff line number Diff line change
Expand Up @@ -5,12 +5,10 @@ using LinearAlgebra
using Test

@testset "discrete univariate" begin

@testset "Bernoulli" begin
d1 = Bernoulli(0.1)
d2 = convolve(d1, d1)

@test isa(d2, Binomial)
d2 = @inferred(convolve(d1, d1))
@test d2 isa Binomial
@test d2.n == 2
@test d2.p == 0.1

Expand All @@ -20,43 +18,37 @@ using Test
# only works if p1 ≈ p2
d3 = Bernoulli(0.2)
@test_throws ArgumentError convolve(d1, d3)

end

@testset "Binomial" begin
d1 = Binomial(2, 0.1)
d2 = Binomial(5, 0.1)
d3 = convolve(d1, d2)

@test isa(d3, Binomial)
d3 = @inferred(convolve(d1, d2))
@test d3 isa Binomial
@test d3.n == 7
@test d3.p == 0.1

# only works if p1 ≈ p2
d4 = Binomial(2, 0.2)
@test_throws ArgumentError convolve(d1, d4)

end

@testset "NegativeBinomial" begin
d1 = NegativeBinomial(4, 0.1)
d2 = NegativeBinomial(1, 0.1)
d3 = convolve(d1, d2)

isa(d3, NegativeBinomial)
d3 = @inferred(convolve(d1, d2))
@test d3 isa NegativeBinomial
@test d3.r == 5
@test d3.p == 0.1

d4 = NegativeBinomial(1, 0.2)
@test_throws ArgumentError convolve(d1, d4)
end


@testset "Geometric" begin
d1 = Geometric(0.2)
d2 = convolve(d1, d1)

@test isa(d2, NegativeBinomial)
@test d2 isa NegativeBinomial
@test d2.p == 0.2

# cannot convolve a Geometric with a NegativeBinomial
Expand All @@ -70,50 +62,43 @@ using Test
@testset "Poisson" begin
d1 = Poisson(0.1)
d2 = Poisson(0.4)
d3 = convolve(d1, d2)

@test isa(d3, Poisson)
d3 = @inferred(convolve(d1, d2))
@test d3 isa Poisson
@test d3.λ == 0.5
end

end

@testset "continuous univariate" begin

@testset "Gaussian" begin
d1 = Normal(0.1, 0.2)
d2 = Normal(0.25, 1.7)
d3 = convolve(d1, d2)

@test isa(d3, Normal)
d3 = @inferred(convolve(d1, d2))
@test d3 isa Normal
@test d3.μ == 0.35
@test d3.σ == hypot(0.2, 1.7)
end

@testset "Cauchy" begin
d1 = Cauchy(0.2, 0.7)
d2 = Cauchy(1.9, 0.8)
d3 = convolve(d1, d2)

@test isa(d3, Cauchy)
d3 = @inferred(convolve(d1, d2))
@test d3 isa Cauchy
@test d3.μ == 2.1
@test d3.σ == 1.5
end

@testset "Chisq" begin
d1 = Chisq(0.1)
d2 = Chisq(0.3)
d3 = convolve(d1, d2)

@test isa(d3, Chisq)
d3 = @inferred(convolve(d1, d2))
@test d3 isa Chisq
@test d3.ν == 0.4
end

@testset "Exponential" begin
d1 = Exponential(0.7)
d2 = convolve(d1, d1)

@test isa(d2, Gamma)
d2 = @inferred(convolve(d1, d1))
@test d2 isa Gamma
@test d2.α == 2
@test d2.θ == 0.7

Expand All @@ -128,23 +113,19 @@ end
@testset "Gamma" begin
d1 = Gamma(0.1, 1.7)
d2 = Gamma(0.5, 1.7)
d3 = convolve(d1, d2)

@test isa(d3, Gamma)
d3 = @inferred(convolve(d1, d2))
@test d3 isa Gamma
@test d3.α == 0.6
@test d3.θ == 1.7

# only works if θ1 ≈ θ4
d4 = Gamma(1.2, 0.4)
@test_throws ArgumentError convolve(d1, d4)
end

end

@testset "continuous multivariate" begin

@testset "iso-/diag-normal" begin

in1 = MvNormal([1.2, 0.3], 2 * I)
in2 = MvNormal([-2.0, 6.9], 0.5 * I)

Expand All @@ -155,74 +136,35 @@ end
dn2 = MvNormal([-3.4, 1.2], Diagonal([3.2, 0.2]))

zmdn1 = MvNormal(Diagonal([1.2, 0.3]))
zmdn2 = MvNormal(Diagonal([-0.8, 1.0]))

dist_list = (in1, in2, zmin1, zmin2, dn1, dn2, zmdn1, zmdn2)

for (d1, d2) in Iterators.product(dist_list, dist_list)
d3 = convolve(d1, d2)
@test d3 isa Union{IsoNormal,DiagNormal,ZeroMeanIsoNormal,ZeroMeanDiagNormal}
@test d3.μ == d1.μ .+ d2.μ
@test Matrix(d3.Σ) == Matrix(d1.Σ + d2.Σ) # isequal not defined for PDMats
end

# erroring
in3 = MvNormal([1, 2, 3], 0.2 * I)
@test_throws ArgumentError convolve(in1, in3)
end


@testset "full-normal" begin
zmdn2 = MvNormal(Diagonal([0.8, 1.0]))

m1 = Symmetric(rand(2,2))
m1sq = m1^2
fn1 = MvNormal(ones(2), m1sq.data)
fn1 = MvNormal(ones(2), m1^2)

m2 = Symmetric(rand(2,2))
m2sq = m2^2
fn2 = MvNormal([2.1, 0.4], m2sq.data)
fn2 = MvNormal([2.1, 0.4], m2^2)

m3 = Symmetric(rand(2,2))
m3sq = m3^2
zm1 = MvNormal(m3sq.data)
zm1 = MvNormal(m3^2)

m4 = Symmetric(rand(2,2))
m4sq = m4^2
zm2 = MvNormal(m4sq.data)
zm2 = MvNormal(m4^2)

dist_list = (fn1, fn2, zm1, zm2)
dist_list = (in1, in2, zmin1, zmin2, dn1, dn2, zmdn1, zmdn2, fn1, fn2, zm1, zm2)

for (d1, d2) in Iterators.product(dist_list, dist_list)
d3 = convolve(d1, d2)
@test d3 isa Union{FullNormal,ZeroMeanFullNormal}
d3 = @inferred(convolve(d1, d2))
@test d3 isa MvNormal
@test d3.μ == d1.μ .+ d2.μ
@test d3.Σ.mat == d1.Σ.mat + d2.Σ.mat # isequal not defined for PDMats
@test Matrix(d3.Σ) == Matrix(d1.Σ + d2.Σ) # isequal not defined for PDMats
end

# erroring
in3 = MvNormal([1, 2, 3], 0.2 * I)
@test_throws ArgumentError convolve(in1, in3)

m5 = Symmetric(rand(3, 3))
m5sq = m5^2
fn3 = MvNormal(zeros(3), m5sq.data)
fn3 = MvNormal(zeros(3), m5^2)
@test_throws ArgumentError convolve(fn1, fn3)
end

@testset "mixed" begin

in1 = MvNormal([1.2, 0.3], 2 * I)
zmin1 = MvNormal(Zeros(2), 1.9 * I)
dn1 = MvNormal([0.0, 4.7], Diagonal([0.1, 1.8]))
zmdn1 = MvNormal(Diagonal([1.2, 0.3]))
m1 = Symmetric(rand(2, 2))
m1sq = m1^2
full = MvNormal(ones(2), m1sq.data)

dist_list = (in1, zmin1, dn1, zmdn1)

for (d1, d2) in Iterators.product((full, ), dist_list)
d3 = convolve(d1, d2)
@test isa(d3, MvNormal)
@test d3.μ == d1.μ .+ d2.μ
@test Matrix(d3.Σ) == Matrix(d1.Σ + d2.Σ) # isequal not defined for PDMats
end
end
end