diff --git a/CHANGELOG.md b/CHANGELOG.md
index 6b2ce8c..158f842 100644
--- a/CHANGELOG.md
+++ b/CHANGELOG.md
@@ -1,5 +1,12 @@
 # Changelog
 
+## [0.19.0] - 2024-12-25
+
+### Changed
+- Merry Christmas
+- `NormalInvChiSquared`, `NormalGamma`, and `NormalInvGamma` `PpCache` for Gaussian conjugate analysis changed. `ln_pp_with_cache` is much faster.
+- `Gamma` `PpCache` for Poisson conjugate analysis has been optimized. `ln_pp_with_cache` is faster.
+
 ## [0.18.0] - 2024-06-24
 
 ### Added
@@ -215,6 +222,7 @@
 - Remove dependency on `quadrature` crate in favor of hand-rolled adaptive Simpson's rule, which handles multimodal distributions better.
 
+[0.19.0]: https://github.com/promise-ai/rv/compare/v0.18.0...v0.19.0
 [0.18.0]: https://github.com/promise-ai/rv/compare/v0.17.0...v0.18.0
 [0.17.0]: https://github.com/promise-ai/rv/compare/v0.16.5...v0.17.0
 [0.16.5]: https://github.com/promise-ai/rv/compare/v0.16.4...v0.16.5
diff --git a/Cargo.lock b/Cargo.lock
index 1aed388..92bcfd9 100644
--- a/Cargo.lock
+++ b/Cargo.lock
@@ -888,7 +888,7 @@ dependencies = [
 
 [[package]]
 name = "rv"
-version = "0.18.0"
+version = "0.19.0"
 dependencies = [
  "approx",
  "argmin",
diff --git a/Cargo.toml b/Cargo.toml
index e1adc38..26049bf 100644
--- a/Cargo.toml
+++ b/Cargo.toml
@@ -1,6 +1,6 @@
 [package]
 name = "rv"
-version = "0.18.0"
+version = "0.19.0"
 authors = ["Baxter Eaves", "Michael Schmidt", "Chad Scherrer"]
 description = "Random variables"
 repository = "https://github.com/promised-ai/rv"
@@ -13,7 +13,7 @@ include = ["README.md", "src/**/*", "benches/*", "Cargo.toml"]
 rust-version = "1.72"
 
 [badges]
-github = { repository = "promised-ai/rv", tag = "v0.17.0" }
+github = { repository = "promised-ai/rv", tag = "v0.19.0" }
 maintenance = { status = "actively-developed" }
 
 [dependencies]
@@ -90,3 +90,11 @@ required-features = ["arraydist"]
 [[bench]]
 name = "mixture_entropy"
 harness = false
+
+[[bench]]
+name = "nix"
+harness = false
+
+[[bench]]
+name = "ng"
+harness = false
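The two bench targets registered above can be run individually with Cargo's standard flags — e.g. `cargo bench --bench ng` or `cargo bench --bench nix`; nothing rv-specific is involved.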
diff --git a/benches/ng.rs b/benches/ng.rs
new file mode 100644
index 0000000..ad73a2b
--- /dev/null
+++ b/benches/ng.rs
@@ -0,0 +1,57 @@
+use criterion::black_box;
+use criterion::BatchSize;
+use criterion::Criterion;
+use criterion::{criterion_group, criterion_main};
+use rv::data::GaussianSuffStat;
+use rv::dist::Gaussian;
+use rv::dist::NormalGamma;
+use rv::traits::*;
+
+fn bench_ng_postpred(c: &mut Criterion) {
+    let mut group = c.benchmark_group("NG ln pp(x)");
+    let ng = NormalGamma::new_unchecked(0.1, 1.2, 2.3, 3.4);
+    let mut rng = rand::thread_rng();
+    let g = Gaussian::standard();
+
+    group.bench_function("No cache", |b| {
+        b.iter_batched(
+            || {
+                let stat = {
+                    let mut stat = GaussianSuffStat::new();
+                    g.sample_stream(&mut rng).take(10).for_each(|x: f64| {
+                        stat.observe(&x);
+                    });
+                    stat
+                };
+                let y: f64 = g.draw(&mut rng);
+                (y, stat)
+            },
+            |(y, stat)| {
+                black_box(ng.ln_pp(&y, &DataOrSuffStat::SuffStat(&stat)))
+            },
+            BatchSize::SmallInput,
+        );
+    });
+
+    group.bench_function("With cache", |b| {
+        b.iter_batched(
+            || {
+                let stat = {
+                    let mut stat = GaussianSuffStat::new();
+                    g.sample_stream(&mut rng).take(10).for_each(|x: f64| {
+                        stat.observe(&x);
+                    });
+                    stat
+                };
+                let y: f64 = g.draw(&mut rng);
+                let cache = ng.ln_pp_cache(&DataOrSuffStat::SuffStat(&stat));
+                (y, cache)
+            },
+            |(y, cache)| black_box(ng.ln_pp_with_cache(&cache, &y)),
+            BatchSize::SmallInput,
+        );
+    });
+}
+
+criterion_group!(ng_benches, bench_ng_postpred);
+criterion_main!(ng_benches);
diff --git a/benches/nix.rs b/benches/nix.rs
new file mode 100644
index 0000000..67c4332
--- /dev/null
+++ b/benches/nix.rs
@@ -0,0 +1,99 @@
+use criterion::black_box;
+use criterion::BatchSize;
+use criterion::Criterion;
+use criterion::{criterion_group, criterion_main};
+use rv::data::GaussianSuffStat;
+use rv::dist::Gaussian;
+use rv::dist::NormalInvChiSquared;
+use rv::traits::*;
+
+fn bench_nix_postpred(c: &mut Criterion) {
+    let mut group = c.benchmark_group("NIX ln pp(x)");
+    let nix = NormalInvChiSquared::new_unchecked(0.1, 1.2, 2.3, 3.4);
+    let mut rng = rand::thread_rng();
+    let g = Gaussian::standard();
+
+    group.bench_function("No cache", |b| {
+        b.iter_batched(
+            || {
+                let stat = {
+                    let mut stat = GaussianSuffStat::new();
+                    g.sample_stream(&mut rng).take(10).for_each(|x: f64| {
+                        stat.observe(&x);
+                    });
+                    stat
+                };
+                let y: f64 = g.draw(&mut rng);
+                (y, stat)
+            },
+            |(y, stat)| {
+                black_box(nix.ln_pp(&y, &DataOrSuffStat::SuffStat(&stat)))
+            },
+            BatchSize::SmallInput,
+        );
+    });
+
+    group.bench_function("With cache", |b| {
+        b.iter_batched(
+            || {
+                let stat = {
+                    let mut stat = GaussianSuffStat::new();
+                    g.sample_stream(&mut rng).take(10).for_each(|x: f64| {
+                        stat.observe(&x);
+                    });
+                    stat
+                };
+                let y: f64 = g.draw(&mut rng);
+                let cache = nix.ln_pp_cache(&DataOrSuffStat::SuffStat(&stat));
+                (y, cache)
+            },
+            |(y, cache)| black_box(nix.ln_pp_with_cache(&cache, &y)),
+            BatchSize::SmallInput,
+        );
+    });
+}
+
+fn bench_gauss_stat(c: &mut Criterion) {
+    let mut group = c.benchmark_group("Gaussian Suffstat");
+
+    let mut rng = rand::thread_rng();
+    let g = Gaussian::standard();
+
+    group.bench_function("Forget", |b| {
+        b.iter_batched(
+            || {
+                let mut stat = GaussianSuffStat::new();
+                for _ in 0..3 {
+                    let x: f64 = g.draw(&mut rng);
+                    stat.observe(&x);
+                }
+                let x: f64 = g.draw(&mut rng);
+                stat.observe(&x);
+                (x, stat)
+            },
+            |(x, mut stat)| {
+                black_box(stat.forget(&x));
+            },
+            BatchSize::SmallInput,
+        );
+    });
+
+    group.bench_function("Observe", |b| {
+        b.iter_batched(
+            || {
+                let mut stat = GaussianSuffStat::new();
+                let x: f64 = g.draw(&mut rng);
+                stat.observe(&x);
+                let x: f64 = g.draw(&mut rng);
+                (x, stat)
+            },
+            |(x, mut stat)| {
+                black_box(stat.observe(&x));
+            },
+            BatchSize::SmallInput,
+        );
+    });
+}
+
+criterion_group!(nix_benches, bench_nix_postpred, bench_gauss_stat);
+criterion_main!(nix_benches);
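The benchmarks above only measure speed. A quick agreement check between the cached and uncached paths would look like this (a test-style sketch, not part of the diff; the tolerance is arbitrary):

    use rv::data::{DataOrSuffStat, GaussianSuffStat};
    use rv::dist::{Gaussian, NormalGamma};
    use rv::traits::*;

    fn main() {
        let ng = NormalGamma::new_unchecked(0.1, 1.2, 2.3, 3.4);
        let g = Gaussian::standard();
        let mut rng = rand::thread_rng();

        let mut stat = GaussianSuffStat::new();
        g.sample_stream(&mut rng)
            .take(10)
            .for_each(|x: f64| stat.observe(&x));

        let y: f64 = g.draw(&mut rng);
        let x = DataOrSuffStat::SuffStat(&stat);

        // The two paths must agree up to floating-point error.
        let cache = ng.ln_pp_cache(&x);
        let diff = (ng.ln_pp(&y, &x) - ng.ln_pp_with_cache(&cache, &y)).abs();
        assert!(diff < 1e-10);
    }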
diff --git a/src/data/mod.rs b/src/data/mod.rs
index 80e2af8..391ad22 100644
--- a/src/data/mod.rs
+++ b/src/data/mod.rs
@@ -17,6 +17,8 @@ pub use stat::UnitPowerLawSuffStat;
 use crate::dist::{
     Bernoulli, Categorical, Gaussian, InvGamma, InvGaussian, Poisson,
 };
+use crate::traits::ConjugatePrior;
+use crate::traits::HasDensity;
 use crate::traits::{HasSuffStat, SuffStat};
 
 pub type BernoulliData<'a, X> = DataOrSuffStat<'a, X, Bernoulli>;
@@ -159,7 +161,7 @@
     pub fn n(&self) -> usize {
         match &self {
             DataOrSuffStat::Data(data) => data.len(),
-            DataOrSuffStat::SuffStat(s) => s.n(),
+            DataOrSuffStat::SuffStat(s) => <Fx::Stat as SuffStat<X>>::n(s),
         }
     }
@@ -212,39 +214,35 @@
 /// Convert a `DataOrSuffStat` into a `Stat`
 #[inline]
-pub fn extract_stat<X, Fx, Ctor>(
-    x: &DataOrSuffStat<X, Fx>,
-    stat_ctor: Ctor,
-) -> Fx::Stat
+pub fn extract_stat<X, Fx, Pr>(pr: &Pr, x: &DataOrSuffStat<X, Fx>) -> Fx::Stat
 where
-    Fx: HasSuffStat<X>,
+    Fx: HasSuffStat<X> + HasDensity<X>,
     Fx::Stat: Clone,
-    Ctor: Fn() -> Fx::Stat,
+    Pr: ConjugatePrior<X, Fx>,
 {
     match x {
         DataOrSuffStat::SuffStat(s) => (*s).clone(),
         DataOrSuffStat::Data(xs) => {
-            let mut stat = stat_ctor();
-            xs.iter().for_each(|y| stat.observe(y));
+            let mut stat = pr.empty_stat();
+            stat.observe_many(xs);
             stat
         }
     }
 }
 
 /// Convert a `DataOrSuffStat` into a `Stat` then do something with it
-#[inline]
-pub fn extract_stat_then<X, Fx, Ctor, Y, Fnx>(
+pub fn extract_stat_then<X, Fx, Pr, Y, Fnx>(
+    pr: &Pr,
     x: &DataOrSuffStat<X, Fx>,
-    stat_ctor: Ctor,
     f_stat: Fnx,
 ) -> Y
 where
-    Fx: HasSuffStat<X>,
+    Fx: HasSuffStat<X> + HasDensity<X>,
     Fx::Stat: Clone,
-    Ctor: Fn() -> Fx::Stat,
+    Pr: ConjugatePrior<X, Fx>,
     Fnx: Fn(Fx::Stat) -> Y,
 {
-    let stat = extract_stat(x, stat_ctor);
+    let stat = extract_stat(pr, x);
     f_stat(stat)
 }
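For downstream code, the practical effect of the new signature: the caller passes the prior, and the empty statistic comes from `ConjugatePrior::empty_stat` rather than from a constructor closure. A hedged before/after sketch (types and values illustrative):

    use rv::data::{extract_stat, DataOrSuffStat};
    use rv::dist::{Bernoulli, Beta};

    let beta = Beta::jeffreys();
    let xs: Vec<bool> = vec![true, false, true];
    let dos: DataOrSuffStat<bool, Bernoulli> = DataOrSuffStat::Data(&xs);

    // 0.18: let stat = extract_stat(&dos, BernoulliSuffStat::new);
    // 0.19: the prior supplies the statistic.
    let stat = extract_stat(&beta, &dos);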
diff --git a/src/data/stat/bernoulli.rs b/src/data/stat/bernoulli.rs
index 2c15956..1f84664 100644
--- a/src/data/stat/bernoulli.rs
+++ b/src/data/stat/bernoulli.rs
@@ -121,6 +121,11 @@ impl<X: Booleable> SuffStat<X> for BernoulliSuffStat {
             self.k -= 1
         }
     }
+
+    fn merge(&mut self, other: Self) {
+        self.n += other.n;
+        self.k += other.k;
+    }
 }
 
 #[cfg(test)]
diff --git a/src/data/stat/beta.rs b/src/data/stat/beta.rs
index bd990f6..27c48ec 100644
--- a/src/data/stat/beta.rs
+++ b/src/data/stat/beta.rs
@@ -131,6 +131,12 @@ macro_rules! impl_suffstat {
                     self.sum_ln_1mx = 0.0;
                 }
             }
+
+            fn merge(&mut self, other: Self) {
+                self.n += other.n;
+                self.sum_ln_x += other.sum_ln_x;
+                self.sum_ln_1mx += other.sum_ln_1mx;
+            }
         }
     };
 }
diff --git a/src/data/stat/categorical.rs b/src/data/stat/categorical.rs
index 4d24b3f..174d421 100644
--- a/src/data/stat/categorical.rs
+++ b/src/data/stat/categorical.rs
@@ -118,6 +118,16 @@ impl<X: CategoricalDatum> SuffStat<X> for CategoricalSuffStat {
         self.n -= 1;
         self.counts[ix] -= 1.0;
     }
+
+    fn merge(&mut self, other: Self) {
+        self.n += other.n;
+        self.counts
+            .iter_mut()
+            .zip(other.counts.iter().copied())
+            .for_each(|(ct, ct_o)| {
+                *ct += ct_o;
+            });
+    }
 }
 
 #[cfg(test)]
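The invariant all of these `merge` implementations aim for: merging statistics collected on disjoint chunks of data should be equivalent to observing all of the data in one statistic. A test-style sketch, not from the diff (UFCS pins down `X`, since `merge` alone does not mention it — the same trick the `incremental_merge` test below uses):

    use rv::data::BernoulliSuffStat;
    use rv::traits::SuffStat;

    let xs = [true, false, true, true];

    let mut full = BernoulliSuffStat::new();
    full.observe_many(&xs);

    // Observe the two halves separately, then merge.
    let mut left = BernoulliSuffStat::new();
    left.observe_many(&xs[..2]);
    let mut right = BernoulliSuffStat::new();
    right.observe_many(&xs[2..]);
    <BernoulliSuffStat as SuffStat<bool>>::merge(&mut left, right);

    assert_eq!(left, full);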
diff --git a/src/data/stat/gaussian.rs b/src/data/stat/gaussian.rs
index 7a80eeb..c3ba7fb 100644
--- a/src/data/stat/gaussian.rs
+++ b/src/data/stat/gaussian.rs
@@ -62,6 +62,11 @@ impl GaussianSuffStat {
         let nf = self.n as f64;
         (self.mean() * self.mean()).mul_add(nf, self.sx)
     }
+
+    #[inline]
+    pub fn sum_sq_diff(&self) -> f64 {
+        self.sx
+    }
 }
 
 impl Default for GaussianSuffStat {
@@ -116,21 +121,30 @@ macro_rules! impl_gaussian_suffstat {
             fn observe(&mut self, x: &$kind) {
                 let xf = f64::from(*x);
 
-                let n = self.n;
-                let mean = self.mean;
-                let sx = self.sx;
+                if self.n == 0 {
+                    *self = GaussianSuffStat {
+                        n: 1,
+                        mean: xf,
+                        sx: 0.0,
+                    };
+                } else {
+                    let n = self.n;
+                    let mean = self.mean;
+                    let sx = self.sx;
 
-                let n1 = n + 1;
-                let mean_xn = (xf - mean).mul_add((n1 as f64).recip(), mean);
+                    let n1 = n + 1;
+                    let mean_xn =
+                        (xf - mean).mul_add((n1 as f64).recip(), mean);
 
-                self.n = n + 1;
-                self.mean = mean_xn;
-                self.sx = (xf - mean).mul_add(xf - mean_xn, sx);
+                    self.n = n + 1;
+                    self.mean = mean_xn;
+                    self.sx = (xf - mean).mul_add(xf - mean_xn, sx);
+                }
             }
 
-            #[inline]
             fn forget(&mut self, x: &$kind) {
                 let n = self.n;
+
                 if n > 1 {
                     let xf = f64::from(*x);
                     let mean = self.mean;
@@ -156,6 +170,27 @@ macro_rules! impl_gaussian_suffstat {
                     };
                 }
             }
+
+            fn merge(&mut self, other: Self) {
+                if other.n == 0 {
+                    return;
+                }
+                let n1 = self.n as f64;
+                let n2 = other.n as f64;
+                let m1 = self.mean;
+                let m2 = other.mean;
+                let sum = n1 + n2;
+
+                let mean = n1.mul_add(m1, n2 * m2) / sum;
+
+                let d1 = m1 - mean;
+                let d2 = m2 - mean;
+                let sx = self.sx + other.sx + n1 * d1 * d1 + n2 * d2 * d2;
+
+                self.mean = mean;
+                self.sx = sx;
+                self.n += other.n;
+            }
         }
     };
 }
@@ -165,6 +200,8 @@ impl_gaussian_suffstat!(f64);
 
 #[cfg(test)]
 mod tests {
+    use crate::traits::Sampleable;
+
     use super::*;
 
     #[test]
@@ -207,4 +244,33 @@ mod tests {
         assert::close(suffstat.sum_x(), 8.1, 1e-14);
         assert::close(suffstat.sum_x_sq(), 27.889_999_999_999_993, 1e-13);
     }
+
+    #[test]
+    fn incremental_merge() {
+        let mut rng = rand::thread_rng();
+        let g = crate::dist::Gaussian::standard();
+
+        let xs: Vec<f64> = g.sample(5, &mut rng);
+        let stat_a = {
+            let mut stat = GaussianSuffStat::new();
+            stat.observe_many(&xs);
+            stat
+        };
+
+        let mut stat_b = {
+            let mut stat = GaussianSuffStat::new();
+            stat.observe(&xs[0]);
+            stat
+        };
+
+        for x in xs.iter().skip(1) {
+            let mut stat_temp = GaussianSuffStat::new();
+            stat_temp.observe(x);
+            <GaussianSuffStat as SuffStat<f64>>::merge(&mut stat_b, stat_temp);
+        }
+
+        assert_eq!(stat_a.n, stat_b.n);
+        assert::close(stat_a.mean, stat_b.mean, 1e-10);
+        assert::close(stat_a.sx, stat_b.sx, 1e-10);
+    }
 }
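Because the Gaussian `merge` pools means and sums of squared deviations rather than just adding counters, it supports map-reduce style accumulation. A sketch (the chunking scheme is illustrative, not from this diff):

    use rv::data::GaussianSuffStat;
    use rv::traits::SuffStat;

    // Reduce per-chunk statistics into one, e.g. after computing
    // them on separate threads or shards.
    fn reduce_stats(chunks: Vec<GaussianSuffStat>) -> GaussianSuffStat {
        let mut acc = GaussianSuffStat::new();
        for stat in chunks {
            <GaussianSuffStat as SuffStat<f64>>::merge(&mut acc, stat);
        }
        acc
    }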
diff --git a/src/data/stat/invgamma.rs b/src/data/stat/invgamma.rs
index 87646b3..fadf145 100644
--- a/src/data/stat/invgamma.rs
+++ b/src/data/stat/invgamma.rs
@@ -133,6 +133,12 @@
                     self.sum_inv_x = 0.0;
                 }
             }
+
+            fn merge(&mut self, other: Self) {
+                self.n += other.n;
+                self.sum_ln_x += other.sum_ln_x;
+                self.sum_inv_x += other.sum_inv_x;
+            }
         }
     };
 }
diff --git a/src/data/stat/invgaussian.rs b/src/data/stat/invgaussian.rs
index cb1f912..3ff50c2 100644
--- a/src/data/stat/invgaussian.rs
+++ b/src/data/stat/invgaussian.rs
@@ -157,6 +157,12 @@ macro_rules! impl_invgaussian_suffstat {
                     self.sum_ln_x = 0.0;
                 }
             }
+            fn merge(&mut self, other: Self) {
+                self.n += other.n;
+                self.sum_x += other.sum_x;
+                self.sum_inv_x += other.sum_inv_x;
+                self.sum_ln_x += other.sum_ln_x;
+            }
         }
     };
 }
diff --git a/src/data/stat/mvg.rs b/src/data/stat/mvg.rs
index 2b1c18a..cefe343 100644
--- a/src/data/stat/mvg.rs
+++ b/src/data/stat/mvg.rs
@@ -80,4 +80,10 @@ impl SuffStat<DVector<f64>> for MvGaussianSuffStat {
             self.sum_x_sq = DMatrix::zeros(dims, dims);
         }
     }
+
+    fn merge(&mut self, other: Self) {
+        self.n += other.n;
+        self.sum_x += other.sum_x;
+        self.sum_x_sq += other.sum_x_sq;
+    }
 }
diff --git a/src/data/stat/poisson.rs b/src/data/stat/poisson.rs
index 38733eb..b8c35e3 100644
--- a/src/data/stat/poisson.rs
+++ b/src/data/stat/poisson.rs
@@ -113,6 +113,12 @@
                     self.sum_ln_fact = 0.0;
                 }
             }
+
+            fn merge(&mut self, other: Self) {
+                self.n += other.n;
+                self.sum += other.sum;
+                self.sum_ln_fact += other.sum_ln_fact;
+            }
         }
     };
 }
diff --git a/src/data/stat/unit_powerlaw.rs b/src/data/stat/unit_powerlaw.rs
index e33020e..0b0c929 100644
--- a/src/data/stat/unit_powerlaw.rs
+++ b/src/data/stat/unit_powerlaw.rs
@@ -124,6 +124,11 @@ macro_rules! impl_suffstat {
                 self.sum_ln_x -=
                     xs.iter().map(|x| f64::from(*x)).product::<f64>().ln();
             }
+
+            fn merge(&mut self, other: Self) {
+                self.n += other.n;
+                self.sum_ln_x += other.sum_ln_x;
+            }
         }
     };
 }
diff --git a/src/dist/beta/bernoulli_prior.rs b/src/dist/beta/bernoulli_prior.rs
index d58a573..3dd0108 100644
--- a/src/dist/beta/bernoulli_prior.rs
+++ b/src/dist/beta/bernoulli_prior.rs
@@ -31,6 +31,10 @@ impl<X: Booleable> ConjugatePrior<X, Bernoulli> for Beta {
     type MCache = f64;
     type PpCache = (f64, f64);
 
+    fn empty_stat(&self) -> <Bernoulli as HasSuffStat<X>>::Stat {
+        BernoulliSuffStat::new()
+    }
+
     #[allow(clippy::many_single_char_names)]
     fn posterior(&self, x: &DataOrSuffStat<X, Bernoulli>) -> Self {
         let (n, k) = match x {
@@ -39,7 +43,9 @@
                 xs.iter().for_each(|x| stat.observe(x));
                 (stat.n(), stat.k())
             }
-            DataOrSuffStat::SuffStat(stat) => (stat.n(), stat.k()),
+            DataOrSuffStat::SuffStat(stat) => {
+                (<BernoulliSuffStat as SuffStat<X>>::n(stat), stat.k())
+            }
         };
 
         let a = self.alpha() + k as f64;
diff --git a/src/dist/dirichlet/categorical_prior.rs b/src/dist/dirichlet/categorical_prior.rs
index 380d29a..3ebb707 100644
--- a/src/dist/dirichlet/categorical_prior.rs
+++ b/src/dist/dirichlet/categorical_prior.rs
@@ -26,17 +26,17 @@ impl<X: CategoricalDatum> ConjugatePrior<X, Categorical>
     type MCache = f64;
     type PpCache = (Vec<f64>, f64);
 
+    fn empty_stat(&self) -> <Categorical as HasSuffStat<X>>::Stat {
+        CategoricalSuffStat::new(self.k())
+    }
+
     fn posterior(&self, x: &CategoricalData<X>) -> Self::Posterior {
-        extract_stat_then(
-            x,
-            || CategoricalSuffStat::new(self.k()),
-            |stat: CategoricalSuffStat| {
-                let alphas: Vec<f64> =
-                    stat.counts().iter().map(|&ct| self.alpha() + ct).collect();
-
-                Dirichlet::new(alphas).unwrap()
-            },
-        )
+        extract_stat_then(self, x, |stat: CategoricalSuffStat| {
+            let alphas: Vec<f64> =
+                stat.counts().iter().map(|&ct| self.alpha() + ct).collect();
+
+            Dirichlet::new(alphas).unwrap()
+        })
     }
 
     #[inline]
@@ -54,20 +54,15 @@
     ) -> f64 {
         let sum_alpha = self.alpha() * self.k() as f64;
 
-        extract_stat_then(
-            x,
-            || CategoricalSuffStat::new(self.k()),
-            |stat: CategoricalSuffStat| {
-                // terms
-                let b = ln_gammafn(sum_alpha + stat.n() as f64);
-                let c = stat
-                    .counts()
-                    .iter()
-                    .fold(0.0, |acc, &ct| acc + ln_gammafn(self.alpha() + ct));
-
-                -b + c + cache
-            },
-        )
+        extract_stat_then(self, x, |stat: CategoricalSuffStat| {
+            let b = ln_gammafn(sum_alpha + stat.n() as f64);
+            let c = stat
+                .counts()
+                .iter()
+                .fold(0.0, |acc, &ct| acc + ln_gammafn(self.alpha() + ct));
+
+            -b + c + cache
+        })
     }
 
     #[inline]
@@ -101,21 +96,21 @@ impl<X: CategoricalDatum> ConjugatePrior<X, Categorical> for Dirichlet {
     type MCache = (f64, f64);
     type PpCache = (Vec<f64>, f64);
 
+    fn empty_stat(&self) -> <Categorical as HasSuffStat<X>>::Stat {
+        CategoricalSuffStat::new(self.k())
+    }
+
     fn posterior(&self, x: &CategoricalData<X>) -> Self::Posterior {
-        extract_stat_then(
-            x,
-            || CategoricalSuffStat::new(self.k()),
-            |stat: CategoricalSuffStat| {
-                let alphas: Vec<f64> = self
-                    .alphas()
-                    .iter()
-                    .zip(stat.counts().iter())
-                    .map(|(&a, &ct)| a + ct)
-                    .collect();
-
-                Dirichlet::new(alphas).unwrap()
-            },
-        )
+        extract_stat_then(self, x, |stat: CategoricalSuffStat| {
+            let alphas: Vec<f64> = self
+                .alphas()
+                .iter()
+                .zip(stat.counts().iter())
+                .map(|(&a, &ct)| a + ct)
+                .collect();
+
+            Dirichlet::new(alphas).unwrap()
+        })
     }
 
     #[inline]
@@ -135,22 +130,17 @@
         x: &CategoricalData<X>,
     ) -> f64 {
         let (sum_alpha, ln_norm) = cache;
-        extract_stat_then(
-            x,
-            || CategoricalSuffStat::new(self.k()),
-            |stat: CategoricalSuffStat| {
-                // terms
-                let b = ln_gammafn(sum_alpha + stat.n() as f64);
-                let c = self
-                    .alphas()
-                    .iter()
-                    .zip(stat.counts().iter())
-                    .map(|(&a, &ct)| ln_gammafn(a + ct))
-                    .sum::<f64>();
-
-                -b + c + ln_norm
-            },
-        )
+        extract_stat_then(self, x, |stat: CategoricalSuffStat| {
+            let b = ln_gammafn(sum_alpha + stat.n() as f64);
+            let c = self
+                .alphas()
+                .iter()
+                .zip(stat.counts().iter())
+                .map(|(&a, &ct)| ln_gammafn(a + ct))
+                .sum::<f64>();
+
+            -b + c + ln_norm
+        })
     }
 
     #[inline]
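One thing `empty_stat` buys over a bare constructor: the prior can encode shape information, as with `CategoricalSuffStat`, which needs `k` up front. A usage sketch (values illustrative; UFCS pins the datum type, as elsewhere in this diff):

    use rv::dist::{Categorical, SymmetricDirichlet};
    use rv::traits::{ConjugatePrior, SuffStat};

    let symdir = SymmetricDirichlet::new(0.5, 4).unwrap();

    // The prior knows the right shape: a four-category count vector.
    let mut stat =
        <SymmetricDirichlet as ConjugatePrior<u8, Categorical>>::empty_stat(&symdir);
    stat.observe(&2_u8);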
diff --git a/src/dist/gamma/poisson_prior.rs b/src/dist/gamma/poisson_prior.rs
index fffa9f8..6ad972c 100644
--- a/src/dist/gamma/poisson_prior.rs
+++ b/src/dist/gamma/poisson_prior.rs
@@ -3,7 +3,7 @@ use rand::Rng;
 use crate::data::PoissonSuffStat;
 use crate::dist::poisson::PoissonError;
 use crate::dist::{Gamma, Poisson};
-use crate::misc::ln_binom;
+use crate::misc::ln_gammafn;
 use crate::traits::*;
 
 impl HasDensity<Poisson> for Gamma {
@@ -46,6 +46,10 @@ macro_rules! impl_traits {
             type MCache = f64;
             type PpCache = (f64, f64, f64);
 
+            fn empty_stat(&self) -> <Poisson as HasSuffStat<$kind>>::Stat {
+                PoissonSuffStat::new()
+            }
+
             fn posterior(&self, x: &DataOrSuffStat<$kind, Poisson>) -> Self {
                 let (n, sum) = match x {
                     DataOrSuffStat::Data(ref xs) => {
@@ -53,9 +57,10 @@
                         xs.iter().for_each(|x| stat.observe(x));
                         (stat.n(), stat.sum())
                     }
-                    DataOrSuffStat::SuffStat(ref stat) => {
-                        (stat.n(), stat.sum())
-                    }
+                    DataOrSuffStat::SuffStat(ref stat) => (
+                        <PoissonSuffStat as SuffStat<$kind>>::n(stat),
+                        stat.sum(),
+                    ),
                 };
 
                 let a = self.shape() + sum;
@@ -104,7 +109,11 @@
                 let post = self.posterior(x);
                 let r = post.shape();
                 let p = 1.0 / (1.0 + post.rate());
-                (r, p, p.ln())
+                let ln_p = -post.rate().ln_1p();
+                let ln_gamma_r = ln_gammafn(post.shape());
+
+                let z = (1.0 - p).ln().mul_add(r, -ln_gamma_r);
+                (z, r, ln_p)
             }
 
             fn ln_pp_with_cache(
@@ -112,10 +121,10 @@
                 cache: &Self::PpCache,
                 y: &$kind,
             ) -> f64 {
-                let (r, p, ln_p) = cache;
+                let (z, r, ln_p) = cache;
                 let k = f64::from(*y);
-                let bnp = ln_binom(k + r - 1.0, k);
-                bnp + (1.0 - p).ln() * r + k * ln_p
+                let bnp = ln_gammafn(k + r) - ln_gammafn(k + 1.0);
+                z + k * ln_p + bnp
             }
         }
    };
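For reference (a derivation note, not text from the diff): the Poisson–Gamma posterior predictive is negative binomial, so

    \ln p(y \mid \mathcal{D})
        = \ln\Gamma(y + r) - \ln\Gamma(y + 1)
        + \underbrace{r \ln(1 - p) - \ln\Gamma(r)}_{z}
        + y \ln p

The cache now stores z = r ln(1 − p) − ln Γ(r) along with r and ln p, so each call to `ln_pp_with_cache` pays only for the y-dependent terms — exactly what the code above computes.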
diff --git a/src/dist/niw/mvg_prior.rs b/src/dist/niw/mvg_prior.rs
index b93a24f..838bbdc 100644
--- a/src/dist/niw/mvg_prior.rs
+++ b/src/dist/niw/mvg_prior.rs
@@ -3,6 +3,7 @@ use crate::data::{extract_stat_then, DataOrSuffStat, MvGaussianSuffStat};
 use crate::dist::{MvGaussian, NormalInvWishart};
 use crate::misc::lnmv_gamma;
 use crate::traits::ConjugatePrior;
+use crate::traits::HasSuffStat;
 use crate::traits::SuffStat;
 use nalgebra::{DMatrix, DVector};
 use std::f64::consts::{LN_2, PI};
@@ -25,44 +26,41 @@ impl ConjugatePrior<DVector<f64>, MvGaussian> for NormalInvWishart {
     type MCache = f64;
     type PpCache = (Self, f64);
 
+    fn empty_stat(&self) -> <MvGaussian as HasSuffStat<DVector<f64>>>::Stat {
+        MvGaussianSuffStat::new(self.ndims())
+    }
+
     fn posterior(&self, x: &MvgData) -> NormalInvWishart {
         if x.n() == 0 {
             return self.clone();
         }
 
         let nf = x.n() as f64;
-        extract_stat_then(
-            x,
-            || MvGaussianSuffStat::new(self.ndims()),
-            |stat: MvGaussianSuffStat| {
-                let xbar = stat.sum_x() / stat.n() as f64;
-                let diff = &xbar - self.mu();
-                // s = \sum_{i=1}^N (x_i - \bar{x}) (x_i - \bar{x})^T
-                //   = \sum_{i=1}^N (x_i x_i^T - x_i \bar{x}^T - \bar{x} x_i^T + \bar{x}\bar{x}^T)
-                //   = N \bar{x} \bar{x}^T + \sum_{i=1}^N (x_i x_i^T - x_i \bar{x}^T - \bar{x} x_i^T)
-                //   = N \bar{x} \bar{x}^T + \sum_{i=1}^N x_i x_i^T
-                //     - (\sum_{i=1}^N x_i) \bar{x}^T - \bar{x} (\sum_{i=1}^N x_i^T)
-                let s: DMatrix<f64> = stat.sum_x_sq()
-                    + nf * (&xbar * &xbar.transpose())
-                    - stat.sum_x() * &xbar.transpose()
-                    - &xbar * stat.sum_x().transpose();
-
-                let kn = self.k() + stat.n() as f64;
-                let vn = self.df() + stat.n();
-                let mn = (self.k() * self.mu() + stat.sum_x()) / kn;
-                let sn = self.scale()
-                    + s
-                    + (self.k() * stat.n() as f64) / kn
-                        * &diff
-                        * &diff.transpose();
-
-                NormalInvWishart::new(mn, kn, vn, sn)
-                    .expect("Invalid posterior parameters")
-            },
-        )
+        extract_stat_then(self, x, |stat: MvGaussianSuffStat| {
+            let xbar = stat.sum_x() / stat.n() as f64;
+            let diff = &xbar - self.mu();
+            // s = \sum_{i=1}^N (x_i - \bar{x}) (x_i - \bar{x})^T
+            //   = \sum_{i=1}^N (x_i x_i^T - x_i \bar{x}^T - \bar{x} x_i^T + \bar{x}\bar{x}^T)
+            //   = N \bar{x} \bar{x}^T + \sum_{i=1}^N (x_i x_i^T - x_i \bar{x}^T - \bar{x} x_i^T)
+            //   = N \bar{x} \bar{x}^T + \sum_{i=1}^N x_i x_i^T
+            //     - (\sum_{i=1}^N x_i) \bar{x}^T - \bar{x} (\sum_{i=1}^N x_i^T)
+            let s: DMatrix<f64> = stat.sum_x_sq()
+                + nf * (&xbar * &xbar.transpose())
+                - stat.sum_x() * &xbar.transpose()
+                - &xbar * stat.sum_x().transpose();
+
+            let kn = self.k() + stat.n() as f64;
+            let vn = self.df() + stat.n();
+            let mn = (self.k() * self.mu() + stat.sum_x()) / kn;
+            let sn = self.scale()
+                + s
+                + (self.k() * stat.n() as f64) / kn * &diff * &diff.transpose();
+
+            NormalInvWishart::new(mn, kn, vn, sn)
+                .expect("Invalid posterior parameters")
+        })
     }
 
-    #[inline]
     fn ln_m_cache(&self) -> f64 {
         ln_z(self.k(), self.df(), self.scale())
     }
@@ -76,7 +74,6 @@
         (nd / 2.0).mul_add(-LN_2PI, zn - z0)
     }
 
-    #[inline]
     fn ln_pp_cache(&self, x: &MvgData) -> Self::PpCache {
         let post = self.posterior(x);
         let zn = ln_z(post.k(), post.df(), post.scale());
diff --git a/src/dist/normal_gamma.rs b/src/dist/normal_gamma.rs
index af272c8..a77d48e 100644
--- a/src/dist/normal_gamma.rs
+++ b/src/dist/normal_gamma.rs
@@ -408,4 +408,32 @@ impl fmt::Display for NormalGammaError {
     }
 }
 
+macro_rules! dos_to_post {
+    (# $self: ident, $stat: ident) => {{
+        match $stat {
+            DataOrSuffStat::SuffStat(stat) => (
+                <GaussianSuffStat as SuffStat<f64>>::n(stat),
+                posterior_from_stat($self, &stat),
+            ),
+            DataOrSuffStat::Data(ref xs) => {
+                let mut stat = GaussianSuffStat::new();
+                stat.observe_many(xs);
+                (stat.n(), posterior_from_stat($self, &stat))
+            }
+        }
+    }};
+    ($self: ident, $stat: ident) => {{
+        match $stat {
+            DataOrSuffStat::SuffStat(stat) => posterior_from_stat($self, &stat),
+            DataOrSuffStat::Data(ref xs) => {
+                let mut stat = GaussianSuffStat::new();
+                stat.observe_many(xs);
+                posterior_from_stat($self, &stat)
+            }
+        }
+    }};
+}
+
+pub(crate) use dos_to_post;
+
 // TODO: tests!
diff --git a/src/dist/normal_gamma/gaussian_prior.rs b/src/dist/normal_gamma/gaussian_prior.rs
index 0e407de..d5870ba 100644
--- a/src/dist/normal_gamma/gaussian_prior.rs
+++ b/src/dist/normal_gamma/gaussian_prior.rs
@@ -1,8 +1,9 @@
 use std::collections::BTreeMap;
 use std::f64::consts::LN_2;
 
+use super::dos_to_post;
 use crate::consts::*;
-use crate::data::{extract_stat, extract_stat_then, GaussianSuffStat};
+use crate::data::{extract_stat, GaussianSuffStat};
 use crate::dist::{Gaussian, NormalGamma};
 use crate::gaussian_prior_geweke_testable;
 use crate::misc::ln_gammafn;
@@ -14,34 +15,50 @@ fn ln_z(r: f64, s: f64, v: f64) -> f64 {
     // This is what is should be in clearer, normal, operations
     // (v + 1.0) / 2.0 * LN_2 + HALF_LN_PI - 0.5 * r.ln() - (v / 2.0) * s.ln()
     //     + ln_gammafn(v / 2.0).0
-    // ... and here is what is is when we use mul_add to reduce rounding errors
+    // ... and here is what it is when we use mul_add to reduce rounding errors
     let half_v = 0.5 * v;
     (half_v + 0.5).mul_add(LN_2, HALF_LN_PI)
         - 0.5_f64.mul_add(r.ln(), half_v.mul_add(s.ln(), -ln_gammafn(half_v)))
 }
 
+pub struct PosteriorParameters {
+    pub m: f64,
+    pub r: f64,
+    pub s: f64,
+    pub v: f64,
+}
+
+impl From<PosteriorParameters> for NormalGamma {
+    fn from(PosteriorParameters { m, r, s, v }: PosteriorParameters) -> Self {
+        NormalGamma::new(m, r, s, v).unwrap()
+    }
+}
+
 fn posterior_from_stat(
     ng: &NormalGamma,
     stat: &GaussianSuffStat,
-) -> NormalGamma {
+) -> PosteriorParameters {
     let nf = stat.n() as f64;
     let r = ng.r() + nf;
     let v = ng.v() + nf;
     let m = ng.m().mul_add(ng.r(), stat.sum_x()) / r;
     let s = ng.s()
         + stat.sum_x_sq()
         + ng.r().mul_add(ng.m() * ng.m(), -r * m * m);
-    NormalGamma::new(m, r, s, v).expect("Invalid posterior params.")
+
+    PosteriorParameters { m, r, s, v }
 }
 
 impl ConjugatePrior<f64, Gaussian> for NormalGamma {
     type Posterior = Self;
     type MCache = f64;
-    type PpCache = (GaussianSuffStat, f64);
+    type PpCache = (PosteriorParameters, f64);
+
+    fn empty_stat(&self) -> <Gaussian as HasSuffStat<f64>>::Stat {
+        GaussianSuffStat::new()
+    }
 
     fn posterior(&self, x: &DataOrSuffStat<f64, Gaussian>) -> Self {
-        extract_stat_then(x, GaussianSuffStat::new, |stat: GaussianSuffStat| {
-            posterior_from_stat(self, &stat)
-        })
+        dos_to_post!(self, x).into()
     }
 
     #[inline]
@@ -54,31 +71,36 @@ impl ConjugatePrior<f64, Gaussian> for NormalGamma {
         cache: &Self::MCache,
         x: &DataOrSuffStat<f64, Gaussian>,
     ) -> f64 {
-        extract_stat_then(x, GaussianSuffStat::new, |stat: GaussianSuffStat| {
-            let post = posterior_from_stat(self, &stat);
-            let lnz_n = ln_z(post.r, post.s, post.v);
-            (-(stat.n() as f64)).mul_add(HALF_LN_2PI, lnz_n) - cache
-        })
+        let (n, post) = dos_to_post!(# self, x);
+        let lnz_n = ln_z(post.r, post.s, post.v);
+        (-(n as f64)).mul_add(HALF_LN_2PI, lnz_n) - cache
     }
 
-    #[inline]
     fn ln_pp_cache(&self, x: &DataOrSuffStat<f64, Gaussian>) -> Self::PpCache {
-        let stat = extract_stat(x, GaussianSuffStat::new);
-        let post_n = posterior_from_stat(self, &stat);
-        let lnz_n = ln_z(post_n.r, post_n.s, post_n.v);
-        (stat, lnz_n)
-    }
+        let stat = extract_stat(self, x);
 
-    fn ln_pp_with_cache(&self, cache: &Self::PpCache, y: &f64) -> f64 {
-        let mut stat = cache.0;
-        let lnz_n = cache.1;
+        let params = posterior_from_stat(self, &stat);
+        let PosteriorParameters { r, s, v, .. } = params;
 
-        stat.observe(y);
-        let post_m = posterior_from_stat(self, &stat);
+        let half_v = v / 2.0;
+        let g_ratio = ln_gammafn(half_v + 0.5) - ln_gammafn(half_v);
+        let term = 0.5_f64.mul_add(LN_2, -HALF_LN_2PI)
+            + 0.5_f64.mul_add(
+                (r / (r + 1_f64)).ln(),
+                half_v.mul_add(s.ln(), g_ratio),
+            );
 
-        let lnz_m = ln_z(post_m.r(), post_m.s(), post_m.v());
+        (params, term)
+    }
+
+    fn ln_pp_with_cache(&self, cache: &Self::PpCache, y: &f64) -> f64 {
+        let PosteriorParameters { m, r, s, v } = cache.0;
 
-        -HALF_LN_2PI + lnz_m - lnz_n
+        let y = *y;
+        let rn = r + 1.0;
+        let mn = r.mul_add(m, y) / rn;
+        let sn = (rn * mn).mul_add(-mn, (r * m).mul_add(m, y.mul_add(y, s)));
+        ((v + 1.0) / 2.0).mul_add(-sn.ln(), cache.1)
     }
 }
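For reference (a derivation sketch, not text from the diff): the old path recomputed \ln Z(r_{n+1}, s_{n+1}, v_{n+1}) - \ln Z(r_n, s_n, v_n) for every query point. Expanding ln Z and collecting the y-independent pieces gives

    \ln p(y \mid \mathcal{D})
        = \tfrac{1}{2}\ln 2 - \tfrac{1}{2}\ln 2\pi
        + \tfrac{1}{2}\ln\tfrac{r}{r+1}
        + \tfrac{v}{2}\ln s
        + \ln\Gamma\!\big(\tfrac{v+1}{2}\big) - \ln\Gamma\!\big(\tfrac{v}{2}\big)
        - \tfrac{v+1}{2}\ln s_n

where (m, r, s, v) are the posterior parameters given the conditioning data and s_n is the one-observation update of s (the `sn` computed in `ln_pp_with_cache`). Everything except the final term is y-independent and lives in the cache, so all `ln_gammafn` calls happen once per cache rather than once per query.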
diff --git a/src/dist/normal_inv_chi_squared/gaussian_prior.rs b/src/dist/normal_inv_chi_squared/gaussian_prior.rs
index c340994..2e93ce1 100644
--- a/src/dist/normal_inv_chi_squared/gaussian_prior.rs
+++ b/src/dist/normal_inv_chi_squared/gaussian_prior.rs
@@ -1,57 +1,77 @@
 use std::collections::BTreeMap;
+use std::f64::consts::PI;
 
 use crate::consts::HALF_LN_PI;
 use crate::data::{extract_stat, extract_stat_then, GaussianSuffStat};
 use crate::dist::{Gaussian, NormalInvChiSquared};
 use crate::gaussian_prior_geweke_testable;
-
+use crate::misc::ln_gammafn;
 use crate::test::GewekeTestable;
 use crate::traits::*;
 
+#[derive(Clone, Debug)]
+pub struct PosteriorParameters {
+    pub mn: f64,
+    pub kn: f64,
+    pub vn: f64,
+    pub s2n: f64,
+}
+
+impl From<PosteriorParameters> for NormalInvChiSquared {
+    fn from(
+        PosteriorParameters { mn, kn, vn, s2n }: PosteriorParameters,
+    ) -> Self {
+        NormalInvChiSquared::new(mn, kn, vn, s2n).unwrap()
+    }
+}
+
 // XXX: Check out section 6.3 from Kevin Murphy's paper
 // https://www.cs.ubc.ca/~murphyk/Papers/bayesGauss.pdf
 fn posterior_from_stat(
     nix: &NormalInvChiSquared,
     stat: &GaussianSuffStat,
-) -> NormalInvChiSquared {
+) -> PosteriorParameters {
+    let (m, k, v, s2) = nix.params();
+
     if stat.n() == 0 {
-        return nix.clone();
+        return PosteriorParameters {
+            mn: m,
+            kn: k,
+            vn: v,
+            s2n: s2,
+        };
     }
 
     let n = stat.n() as f64;
 
-    let (m, k, v, s2) = nix.params();
-
     let xbar = stat.mean();
-    let sum_x = xbar * n;
-    // Sum (x - xbar)^2
-    //  = Sum[ x*x - 2x*xbar + xbar*xbar ]
-    //  = Sum[x^2] + n * xbar^2 - 2 * xbar + Sum[x]
-    //  = Sum[x^2] + n * xbar^2 - 2 * n *xbar^2
-    //  = Sum[x^2] - n * xbar^2
-    let mid = (n * xbar).mul_add(-xbar, stat.sum_x_sq());
 
     let kn = k + n;
-    let divby_k_plus_n = kn.recip();
+    let kn_recip = kn.recip();
     let vn = v + n;
-    let mn = k.mul_add(m, sum_x) * divby_k_plus_n;
+    let mn = k.mul_add(m, stat.sum_x()) * kn_recip;
 
     let diff_m_xbar = m - xbar;
 
-    let s2n = ((n * k * divby_k_plus_n) * diff_m_xbar)
-        .mul_add(diff_m_xbar, v.mul_add(s2, mid))
-        / vn;
+    let s2n = v.mul_add(
+        s2,
+        ((n * k * kn_recip) * diff_m_xbar)
+            .mul_add(diff_m_xbar, stat.sum_sq_diff()),
+    ) / vn;
 
-    NormalInvChiSquared::new(mn, kn, vn, s2n)
-        .expect("Invalid posterior params.")
+    PosteriorParameters { mn, kn, vn, s2n }
 }
 
 impl ConjugatePrior<f64, Gaussian> for NormalInvChiSquared {
     type Posterior = Self;
     type MCache = f64;
-    type PpCache = (GaussianSuffStat, f64);
+    type PpCache = (PosteriorParameters, f64);
+
+    fn empty_stat(&self) -> <Gaussian as HasSuffStat<f64>>::Stat {
+        GaussianSuffStat::new()
+    }
 
     fn posterior(&self, x: &DataOrSuffStat<f64, Gaussian>) -> Self {
-        extract_stat_then(x, GaussianSuffStat::new, |stat: GaussianSuffStat| {
-            posterior_from_stat(self, &stat)
+        extract_stat_then(self, x, |stat: GaussianSuffStat| {
+            posterior_from_stat(self, &stat).into()
         })
     }
 
@@ -65,33 +85,38 @@ impl ConjugatePrior<f64, Gaussian> for NormalInvChiSquared {
         cache: &Self::MCache,
         x: &DataOrSuffStat<f64, Gaussian>,
     ) -> f64 {
-        extract_stat_then(x, GaussianSuffStat::new, |stat: GaussianSuffStat| {
+        extract_stat_then(self, x, |stat: GaussianSuffStat| {
             let n = stat.n() as f64;
-            let post = posterior_from_stat(self, &stat);
+            let post: Self = posterior_from_stat(self, &stat).into();
             let lnz_n = post.ln_z();
             n.mul_add(-HALF_LN_PI, lnz_n - cache)
         })
     }
 
-    #[inline]
     fn ln_pp_cache(&self, x: &DataOrSuffStat<f64, Gaussian>) -> Self::PpCache {
-        let stat = extract_stat(x, GaussianSuffStat::new);
-        let post_n = posterior_from_stat(self, &stat);
-        let lnz_n = post_n.ln_z();
-        (stat, lnz_n)
-        // post_n
+        let stat = extract_stat(self, x);
+        let post = posterior_from_stat(self, &stat);
+        let kn = post.kn;
+        let vn = post.vn;
+
+        let z = 0.5_f64.mul_add(
+            (kn / ((kn + 1.0) * PI * vn * post.s2n)).ln(),
+            ln_gammafn((vn + 1.0) / 2.0) - ln_gammafn(vn / 2.0),
+        );
+        (post, z)
     }
 
     fn ln_pp_with_cache(&self, cache: &Self::PpCache, y: &f64) -> f64 {
-        let mut stat = cache.0;
-        let lnz_n = cache.1;
+        let post = &cache.0;
+        let z = cache.1;
+        let kn = post.kn;
 
-        stat.observe(y);
-        let post_m = posterior_from_stat(self, &stat);
+        let diff = y - post.mn;
 
-        let lnz_m = post_m.ln_z();
-
-        -HALF_LN_PI + lnz_m - lnz_n
+        ((post.vn + 1.0) / 2.0).mul_add(
+            -((kn * diff * diff) / ((kn + 1.0) * post.vn * post.s2n)).ln_1p(),
+            z,
+        )
    }
}
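For reference (cf. the Murphy notes cited in the comment above; not text from the diff): the NIX posterior predictive is a Student-t,

    p(y \mid \mathcal{D}) = t_{v_n}\!\Big(y \;\Big|\; m_n,\ \frac{(k_n + 1)\, s^2_n}{k_n}\Big)

The cached z is exactly that distribution's log-normalizer, so each `ln_pp_with_cache` call pays only for the `ln_1p` kernel term rather than two posterior constructions and their Gamma-function evaluations.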
diff --git a/src/dist/normal_inv_gamma.rs b/src/dist/normal_inv_gamma.rs
index 8cdeba5..c1e761c 100644
--- a/src/dist/normal_inv_gamma.rs
+++ b/src/dist/normal_inv_gamma.rs
@@ -7,11 +7,12 @@ use serde::{Deserialize, Serialize};
 
 mod gaussian_prior;
 
+use rand::Rng;
+use std::fmt;
+
 use crate::dist::{Gaussian, InvGamma};
 use crate::impl_display;
 use crate::traits::*;
-use rand::Rng;
-use std::fmt;
 
 /// Prior for Gaussian
 ///
diff --git a/src/dist/normal_inv_gamma/gaussian_prior.rs b/src/dist/normal_inv_gamma/gaussian_prior.rs
index b4498bf..0d3fc80 100644
--- a/src/dist/normal_inv_gamma/gaussian_prior.rs
+++ b/src/dist/normal_inv_gamma/gaussian_prior.rs
@@ -2,6 +2,7 @@ use std::collections::BTreeMap;
 
 use crate::consts::HALF_LN_2PI;
 use crate::data::{extract_stat, extract_stat_then, GaussianSuffStat};
+use crate::dist::normal_gamma::dos_to_post;
 use crate::dist::{Gaussian, NormalInvGamma};
 use crate::gaussian_prior_geweke_testable;
 use crate::misc::ln_gammafn;
@@ -10,18 +11,30 @@ use crate::traits::*;
 
 #[inline]
 fn ln_z(v: f64, a: f64, b: f64) -> f64 {
-    // -(a * b.ln() - 0.5 * v.ln() - a.ln_gamma().0)
     let p1 = v.ln().mul_add(0.5, ln_gammafn(a));
     -b.ln().mul_add(a, -p1)
 }
 
+pub struct PosteriorParameters {
+    m: f64,
+    v: f64,
+    a: f64,
+    b: f64,
+}
+
+impl From<PosteriorParameters> for NormalInvGamma {
+    fn from(PosteriorParameters { m, v, a, b }: PosteriorParameters) -> Self {
+        NormalInvGamma::new(m, v, a, b).unwrap()
+    }
+}
+
 // XXX: Check out section 6.3 from Kevin Murphy's paper
 // https://www.cs.ubc.ca/~murphyk/Papers/bayesGauss.pdf
 #[allow(clippy::many_single_char_names)]
 fn posterior_from_stat(
     nig: &NormalInvGamma,
     stat: &GaussianSuffStat,
-) -> NormalInvGamma {
+) -> PosteriorParameters {
     let n = stat.n() as f64;
 
     let super::NormalInvGammaParameters { m, v, a, b } = nig.emit_params();
@@ -30,27 +43,30 @@
     let vn_inv = v_inv + n;
     let vn = vn_inv.recip();
-    // let mn = (v_inv * m + stat.sum_x()) * vn;
     let mn = v_inv.mul_add(m, stat.sum_x()) / vn_inv;
-    // let an = a + 0.5 * n;
     let an = n.mul_add(0.5, a);
-    // let bn = b + 0.5 * (m * m * v_inv + stat.sum_x_sq() - mn * mn * vn_inv);
     let p1 = (m * m).mul_add(v_inv, stat.sum_x_sq());
     let bn = (-mn * mn).mul_add(vn_inv, p1).mul_add(0.5, b);
 
-    NormalInvGamma::new(mn, vn, an, bn).expect("Invalid posterior params.")
+    PosteriorParameters {
+        m: mn,
+        v: vn,
+        a: an,
+        b: bn,
+    }
 }
 
 impl ConjugatePrior<f64, Gaussian> for NormalInvGamma {
     type Posterior = Self;
     type MCache = f64;
-    type PpCache = (GaussianSuffStat, f64);
-    // type PpCache = NormalInvGamma;
+    type PpCache = (PosteriorParameters, f64);
+
+    fn empty_stat(&self) -> <Gaussian as HasSuffStat<f64>>::Stat {
+        GaussianSuffStat::new()
+    }
 
     fn posterior(&self, x: &DataOrSuffStat<f64, Gaussian>) -> Self {
-        extract_stat_then(x, GaussianSuffStat::new, |stat: GaussianSuffStat| {
-            posterior_from_stat(self, &stat)
-        })
+        dos_to_post!(self, x).into()
     }
 
     #[inline]
@@ -63,33 +79,35 @@
         cache: &Self::MCache,
         x: &DataOrSuffStat<f64, Gaussian>,
     ) -> f64 {
-        extract_stat_then(x, GaussianSuffStat::new, |stat: GaussianSuffStat| {
-            let post = posterior_from_stat(self, &stat);
-            let n = stat.n() as f64;
-            let lnz_n = ln_z(post.v, post.a, post.b);
-            n.mul_add(-HALF_LN_2PI, lnz_n - cache)
-            // lnz_n - cache - n * HALF_LN_PI - n*LN_2
-        })
+        let (n, post) = dos_to_post!(# self, x);
+        let lnz_n = ln_z(post.v, post.a, post.b);
+        (n as f64).mul_add(-HALF_LN_2PI, lnz_n - cache)
    }
 
-    #[inline]
     fn ln_pp_cache(&self, x: &DataOrSuffStat<f64, Gaussian>) -> Self::PpCache {
-        let stat = extract_stat(x, GaussianSuffStat::new);
-        let post_n = posterior_from_stat(self, &stat);
-        let lnz_n = ln_z(post_n.v, post_n.a, post_n.b);
-        (stat, lnz_n)
+        let params = dos_to_post!(self, x);
+        let PosteriorParameters { v, a, b, .. } = params;
+
+        let gamma_ratio = ln_gammafn(a + 0.5) - ln_gammafn(a);
+        let z = (-0.5_f64).mul_add(v.ln_1p(), a * b.ln()) + gamma_ratio
+            - HALF_LN_2PI;
+
+        (params, z)
     }
 
     fn ln_pp_with_cache(&self, cache: &Self::PpCache, y: &f64) -> f64 {
-        let mut stat = cache.0;
-        let lnz_n = cache.1;
-
-        stat.observe(y);
-        let post_m = posterior_from_stat(self, &stat);
+        let PosteriorParameters { m, v, a, b } = cache.0;
 
-        let lnz_m = ln_z(post_m.v, post_m.a, post_m.b);
+        let y = *y;
+        let v_recip = v.recip();
+        let vn_recip = v_recip + 1.0;
+        let mn = v_recip.mul_add(m, y) / vn_recip;
+        let bn = 0.5_f64.mul_add(
+            (mn * mn).mul_add(-vn_recip, (m * m).mul_add(v_recip, y * y)),
+            b,
+        );
 
-        -HALF_LN_2PI + lnz_m - lnz_n
+        (a + 0.5).mul_add(-bn.ln(), cache.1)
    }
}
diff --git a/src/dist/unit_powerlaw/bernoulli_prior.rs b/src/dist/unit_powerlaw/bernoulli_prior.rs
index a857bf7..c925cd5 100644
--- a/src/dist/unit_powerlaw/bernoulli_prior.rs
+++ b/src/dist/unit_powerlaw/bernoulli_prior.rs
@@ -31,6 +31,10 @@ impl<X: Booleable> ConjugatePrior<X, Bernoulli> for UnitPowerLaw {
     type MCache = f64;
     type PpCache = (f64, f64);
 
+    fn empty_stat(&self) -> <Bernoulli as HasSuffStat<X>>::Stat {
+        BernoulliSuffStat::new()
+    }
+
     #[allow(clippy::many_single_char_names)]
     fn posterior(&self, x: &DataOrSuffStat<X, Bernoulli>) -> Beta {
         let (n, k) = match x {
@@ -39,7 +43,9 @@
                 xs.iter().for_each(|x| stat.observe(x));
                 (stat.n(), stat.k())
             }
-            DataOrSuffStat::SuffStat(stat) => (stat.n(), stat.k()),
+            DataOrSuffStat::SuffStat(stat) => {
+                (<BernoulliSuffStat as SuffStat<X>>::n(stat), stat.k())
+            }
         };
 
         let a = self.alpha() + k as f64;
diff --git a/src/experimental/stick_breaking_process/sbd_stat.rs b/src/experimental/stick_breaking_process/sbd_stat.rs
index 3942712..f866eed 100644
--- a/src/experimental/stick_breaking_process/sbd_stat.rs
+++ b/src/experimental/stick_breaking_process/sbd_stat.rs
@@ -29,6 +29,10 @@ impl StickBreakingDiscreteSuffStat {
         Self { counts: Vec::new() }
     }
 
+    pub fn from_counts(counts: Vec<usize>) -> Self {
+        Self { counts }
+    }
+
     /// Calculates break pairs for probabilities.
     ///
     /// Returns a vector of pairs where each pair consists of the sum of all counts after the current index and the count at the current index.
@@ -156,6 +160,16 @@ impl SuffStat<usize> for StickBreakingDiscreteSuffStat {
         assert!(self.counts[*i] > 0, "No observations of {i} to forget.");
         self.counts[*i] -= 1;
     }
+
+    fn merge(&mut self, other: Self) {
+        if other.counts.len() > self.counts.len() {
+            self.counts.resize(other.counts.len(), 0);
+        }
+        self.counts
+            .iter_mut()
+            .zip(other.counts.iter())
+            .for_each(|(ct_a, &ct_b)| *ct_a += ct_b);
+    }
 }
 
 #[cfg(test)]
 mod tests {
diff --git a/src/experimental/stick_breaking_process/stick_breaking.rs b/src/experimental/stick_breaking_process/stick_breaking.rs
index 52fa680..ea92971 100644
--- a/src/experimental/stick_breaking_process/stick_breaking.rs
+++ b/src/experimental/stick_breaking_process/stick_breaking.rs
@@ -1,6 +1,5 @@
 use crate::experimental::stick_breaking_process::StickBreakingDiscrete;
 use crate::experimental::stick_breaking_process::StickBreakingDiscreteSuffStat;
-// use crate::experimental::stick_breaking_process::StickBreakingSuffStat;
 use crate::experimental::stick_breaking_process::StickSequence;
 use crate::prelude::*;
 use crate::traits::*;
@@ -13,10 +12,10 @@ use special::Beta as BetaFn;
 #[cfg(feature = "serde1")]
 use serde::{Deserialize, Serialize};
 
+/// Represents a stick-breaking process.
 #[cfg_attr(feature = "serde1", derive(Serialize, Deserialize))]
 #[cfg_attr(feature = "serde1", serde(rename_all = "snake_case"))]
 #[derive(Clone, Debug, PartialEq)]
-/// Represents a stick-breaking process.
 pub struct StickBreaking {
     break_prefix: Vec<Beta>,
     break_tail: UnitPowerLaw,
@@ -241,6 +240,12 @@ impl ConjugatePrior<usize, StickBreakingDiscrete> for StickBreaking {
     type MCache = ();
     type PpCache = Self::Posterior;
 
+    fn empty_stat(
+        &self,
+    ) -> <StickBreakingDiscrete as HasSuffStat<usize>>::Stat {
+        StickBreakingDiscreteSuffStat::new()
+    }
+
     /// Computes the logarithm of the marginal likelihood cache.
     fn ln_m_cache(&self) -> Self::MCache {}
 
diff --git a/src/experimental/stick_breaking_process/stick_breaking_stat.rs b/src/experimental/stick_breaking_process/stick_breaking_stat.rs
index 1f625cc..b2a05a4 100644
--- a/src/experimental/stick_breaking_process/stick_breaking_stat.rs
+++ b/src/experimental/stick_breaking_process/stick_breaking_stat.rs
@@ -205,4 +205,14 @@ impl SuffStat<&[f64]> for StickBreakingSuffStat {
         self.num_breaks -= num_breaks;
         self.sum_log_q -= sum_log_q;
     }
+
+    fn merge(&mut self, other: Self) {
+        if other.n == 0 {
+            return;
+        }
+        self.n += other.n;
+        self.sum_log_q += other.sum_log_q;
+        // FIXME: is this right?
+        self.num_breaks += other.num_breaks;
+    }
 }
diff --git a/src/experimental/stick_breaking_process/stick_sequence.rs b/src/experimental/stick_breaking_process/stick_sequence.rs
index 9f2ce15..e29394c 100644
--- a/src/experimental/stick_breaking_process/stick_sequence.rs
+++ b/src/experimental/stick_breaking_process/stick_sequence.rs
@@ -60,6 +60,10 @@ impl _Inner {
         }
     }
 
+    pub fn ccdf(&self) -> &[f64] {
+        &self.ccdf
+    }
+
     fn extend<B: Sampleable<f64> + Clone>(&mut self, breaker: &B) -> f64 {
         let p: f64 = breaker.draw(&mut self.rng);
         let remaining_mass = self.ccdf.last().unwrap();
diff --git a/src/misc/func.rs b/src/misc/func.rs
index c7b73f0..4cbab02 100644
--- a/src/misc/func.rs
+++ b/src/misc/func.rs
@@ -127,7 +127,9 @@ where
         let (alpha, r) = self.fold((f64::NEG_INFINITY, 0.0), |(alpha, r), x| {
             let x = *x.borrow();
-            if x <= alpha {
+            if x == f64::NEG_INFINITY {
+                return (alpha, r);
+            } else if x <= alpha {
                 (alpha, r + (x - alpha).exp())
             } else {
                 (x, (alpha - x).exp().mul_add(r, 1.0))
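Why the new branch matters: when the running maximum `alpha` is still `NEG_INFINITY`, a `NEG_INFINITY` input reaches `(x - alpha).exp()`, and `(-inf) - (-inf)` is NaN; skipping those entries treats them as zero-mass terms. A sketch of the behavior this buys (assuming the public `rv::misc::logsumexp` slice helper routes through this fold):

    use rv::misc::logsumexp;

    // A log-weight of NEG_INFINITY is a zero-probability entry; it should
    // not poison the sum, even in the first position.
    let weights = [f64::NEG_INFINITY, 0.5_f64.ln()];
    assert!((logsumexp(&weights) - 0.5_f64.ln()).abs() < 1e-12);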
diff --git a/src/model.rs b/src/model.rs
index efd96e9..e8d31be 100644
--- a/src/model.rs
+++ b/src/model.rs
@@ -93,23 +93,23 @@ where
     }
 }
 
-impl<X, Fx, Pr> SuffStat<X> for ConjugateModel<X, Fx, Pr>
-where
-    Fx: Rv<X> + HasSuffStat<X>,
-    Pr: ConjugatePrior<X, Fx>,
-{
-    fn n(&self) -> usize {
-        self.suffstat.n()
-    }
+// impl<X, Fx, Pr> SuffStat<X> for ConjugateModel<X, Fx, Pr>
+// where
+//     Fx: Rv<X> + HasSuffStat<X>,
+//     Pr: ConjugatePrior<X, Fx>,
+// {
+//     fn n(&self) -> usize {
+//         self.suffstat.n()
+//     }
 
-    fn observe(&mut self, x: &X) {
-        self.suffstat.observe(x);
-    }
+//     fn observe(&mut self, x: &X) {
+//         self.suffstat.observe(x);
+//     }
 
-    fn forget(&mut self, x: &X) {
-        self.suffstat.forget(x);
-    }
-}
+//     fn forget(&mut self, x: &X) {
+//         self.suffstat.forget(x);
+//     }
+// }
 
 impl<X, Fx, Pr> HasDensity<X> for ConjugateModel<X, Fx, Pr>
 where
diff --git a/src/traits.rs b/src/traits.rs
index 54e9ea5..ec575b2 100644
--- a/src/traits.rs
+++ b/src/traits.rs
@@ -541,9 +541,10 @@ where
     /// Type of the cache for the posterior predictive
     type PpCache;
 
-    /// Computes the posterior distribution from the data
-    // fn posterior(&self, x: &DataOrSuffStat<X, Fx>) -> Self::Posterior;
+    /// Generate an empty sufficient statistic
+    fn empty_stat(&self) -> Fx::Stat;
 
+    /// Computes the posterior distribution from the data
     fn posterior_from_suffstat(&self, stat: &Fx::Stat) -> Self::Posterior {
         self.posterior(&DataOrSuffStat::SuffStat(stat))
     }
@@ -682,4 +683,7 @@ pub trait SuffStat<X> {
     fn forget_many(&mut self, xs: &[X]) {
         xs.iter().for_each(|x| self.forget(x));
     }
+
+    /// Combine sufficient statistics
+    fn merge(&mut self, other: Self);
 }
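Since `merge` is added as a required method with no default, downstream `SuffStat` implementations must now supply one. A minimal sketch for a hypothetical counting statistic (not from this codebase):

    use rv::traits::SuffStat;

    // Hypothetical downstream statistic: just counts observations.
    struct CountStat {
        n: usize,
    }

    impl SuffStat<f64> for CountStat {
        fn n(&self) -> usize {
            self.n
        }

        fn observe(&mut self, _x: &f64) {
            self.n += 1;
        }

        fn forget(&mut self, _x: &f64) {
            self.n -= 1;
        }

        // New in 0.19: combine two partial statistics.
        fn merge(&mut self, other: Self) {
            self.n += other.n;
        }
    }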