-
Notifications
You must be signed in to change notification settings - Fork 0
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Some work on bayesian binomial model
- Loading branch information
Piotr Chlebicki
committed
Apr 19, 2024
1 parent
7f197bf
commit d3a1223
Showing
8 changed files
with
422 additions
and
103 deletions.
There are no files selected for viewing
Large diffs are not rendered by default.
Oops, something went wrong.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file was deleted.
Oops, something went wrong.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,117 @@ | ||
function sample_gamma_1_cond(grid, n, N, γ₂, M, m, μ_γ₁, μ_γ₂, ρ, σ_γ₁, σ_γ₂, ε = 1e-6) | ||
# get posteriori normal parameters | ||
μ_γ₁_post = μ_γ₁ + ρ * σ_γ₂ / σ_γ₁ * (γ₂ - μ_γ₂) | ||
σ_γ₁_post = sqrt(1 - ρ ^ 2) * σ_γ₁ | ||
distr = Normal(μ_γ₁, σ_γ₁) | ||
# compute R_i's | ||
#log_N_sq = log.(N) .^ 2 | ||
#= R = zeros(BigFloat, length(M)) | ||
next_iter = true | ||
t = 0 | ||
while next_iter | ||
println(t) | ||
# Rprev = copy(R) | ||
# iterate sum | ||
R_add = cgf.(distr, (M .- t) .* log.(N)) .- logfactorial(t) | ||
R .+= exp.(R_add) .* (-1) ^ t | ||
println(R_add) | ||
# check convergence | ||
next_iter = any(exp.(R_add) .< ε) | t > 1 | ||
t += 1 | ||
end # end while | ||
println(R) =# | ||
f(x, p) = exp.(x .* M .* log.(N) - N .^ x) .* pdf(distr, x) | ||
prob = IntegralProblem(f, [-Inf, Inf]) | ||
R = solve(prob, HCubatureJL(), reltol = ε, abstol = ε) | ||
#println(R) | ||
R .*= exp.(-logfactorial.(M)) # <--- if this fails this is probably why | ||
#println(R) | ||
#error("bcd") | ||
|
||
# shift grid towards the mean since it is most probable | ||
grid1 = grid .+ μ_γ₁_post | ||
# get unscaled density function and evaluate it on a grid | ||
function density_function(x) | ||
μ = (N .^ x) .* ((n ./ N) .^ γ₂) | ||
μ = 1 ./ (1 .+ exp.(-μ)) | ||
lξ = x .* log.(N) | ||
|
||
#res = (μ .^ m) .* ((1 .- μ) .^ (M .- m)) | ||
res = zeros(BigFloat, length(M)) | ||
res .+= m .* log.(μ) .+ (M .- m) .* log.(1 .- μ) .- log.(R) .- logfactorial.(M .- m) .- exp.(lξ) .+ M .* lξ | ||
#println(exp(sum(res))) | ||
exp(sum(res)) .* pdf(Normal(μ_γ₁_post, σ_γ₁_post), x) | ||
end # end funciton | ||
evaluated_denisty = density_function.(grid1) | ||
#println(evaluated_denisty) | ||
evaluated_denisty ./= sum(evaluated_denisty) | ||
# sample acording to evaluation | ||
grid1[rand(Categorical(evaluated_denisty))] | ||
end # end funciton | ||
|
||
function sample_gamma_2_cond(grid, n, N, γ₁, M, m, μ_γ₁, μ_γ₂, ρ, σ_γ₁, σ_γ₂) | ||
# get posteriori normal parameters | ||
μ_γ₂_post = μ_γ₂ + ρ * σ_γ₁ / σ_γ₂ * (γ₁ - μ_γ₁) | ||
σ_γ₂_post = sqrt(1 - ρ ^ 2) * σ_γ₂ | ||
# shift grid towards the mean since it is most probable | ||
grid1 = grid .- μ_γ₂_post | ||
# get unscaled density function and evaluate it on a grid | ||
function density_function(x) | ||
μ = (N .^ γ₁) .* ((n ./ N) .^ x) | ||
μ = 1 ./ (1 .+ exp.(-μ)) | ||
#ξ = N .^ γ₁ | ||
|
||
res = logfactorial.(M) .- logfactorial.(M .- m) .- logfactorial.(m) .+ m .* log.(μ).+ (M .- m) .* log.(1 .- μ) | ||
exp(sum(res)) * pdf(Normal(μ_γ₂_post, σ_γ₂_post), x) | ||
end # end funciton | ||
evaluated_denisty = density_function.(grid1) | ||
evaluated_denisty ./= sum(evaluated_denisty) | ||
# sample acording to evaluation | ||
grid1[rand(Categorical(evaluated_denisty))] | ||
end # end funciton | ||
|
||
function sample_M_cond(n, N, m, γ₁, γ₂) | ||
# compute ξ, μ | ||
μ = (N .^ γ₁) .* ((n ./ N) .^ γ₂) | ||
μ = 1 ./ (1 .+ exp.(-μ)) | ||
ξ = N .^ γ₁ | ||
# draw M-m vector from poisson intependently | ||
M_minus_m = reduce(vcat, rand.(Poisson.(ξ .* (1 .- μ)), 1)) | ||
# return M = (M-m) + increment | ||
m + M_minus_m | ||
end # end funciton | ||
|
||
function gibbs_sampler_binomial_model(start, grid, iter, n, N, m, μ_γ₁, μ_γ₂, σ_γ₁, σ_γ₂, ρ, ε = 1e-6) | ||
# create storage vectors | ||
M = start[1] | ||
γ₁ = start[2] | ||
γ₂ = start[3] | ||
|
||
storage = [[M], [γ₁], [γ₂]] | ||
|
||
for k in iter | ||
# sample M conditional on γ₁ and γ₂ | ||
M = sample_M_cond(n, N, m, γ₁, γ₂) | ||
#println(M) | ||
# sample γ₁ conditional on M and γ₂ | ||
γ₁ = sample_gamma_1_cond( | ||
grid, n, N, γ₂, | ||
M, m, μ_γ₁, μ_γ₂, | ||
ρ, σ_γ₁, σ_γ₂, ε | ||
) | ||
#println(γ₁) | ||
# sample γ₂ conditional on γ₁ and M | ||
γ₂ = sample_gamma_2_cond( | ||
grid, n, N, γ₁, | ||
M, m, μ_γ₁, μ_γ₂, | ||
ρ, σ_γ₁, σ_γ₂ | ||
) | ||
#println(γ₂) | ||
# store them | ||
append!(storage[1], M) | ||
append!(storage[2], γ₁) | ||
append!(storage[3], γ₂) | ||
end # end for | ||
# return stored values | ||
storage | ||
end # end funciton |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,21 @@ | ||
N,n,m,M,ξ | ||
4946,501,616,907,902.4087332179555 | ||
4886,480,661,965,893.6403502565424 | ||
4935,506,608,887,900.8027961854027 | ||
4943,488,598,895,901.9708212977752 | ||
5079,535,647,943,921.7700017557884 | ||
5054,516,587,897,918.1384813234845 | ||
5027,521,658,922,914.2124012586877 | ||
4890,507,604,862,894.2255764308815 | ||
5022,490,593,868,913.4848871407867 | ||
4928,477,589,888,899.7804635417899 | ||
4983,513,636,930,907.8052912210046 | ||
4968,485,595,917,905.6184666779144 | ||
5144,543,716,970,931.195284188665 | ||
4993,507,643,942,909.2624427471206 | ||
4854,472,605,879,888.9550830898875 | ||
4988,499,588,884,908.5339400169311 | ||
4977,489,617,917,906.9307196114336 | ||
5053,524,617,902,917.9931458823555 | ||
4950,531,629,880,902.9925331538545 | ||
4845,492,551,837,887.6362400289374 |