From 7619f18f04ced40dae3f6b52ffc8d1b4e1409fa6 Mon Sep 17 00:00:00 2001 From: Chad Scherrer Date: Wed, 31 Jul 2024 15:11:15 -0700 Subject: [PATCH 1/9] make logsumexp take an impl Iterator --- src/dist/categorical.rs | 8 ++- src/dist/cauchy.rs | 2 +- src/dist/mixture.rs | 7 ++- src/misc/func.rs | 117 +++++++++++++++++++--------------------- 4 files changed, 63 insertions(+), 71 deletions(-) diff --git a/src/dist/categorical.rs b/src/dist/categorical.rs index e7150b1..66d78d3 100644 --- a/src/dist/categorical.rs +++ b/src/dist/categorical.rs @@ -92,10 +92,8 @@ impl Categorical { } })?; - let ln_weights: Vec = weights.iter().map(|w| w.ln()).collect(); - let ln_norm = logsumexp(&ln_weights); - let normed_weights = - ln_weights.iter().map(|lnw| lnw - ln_norm).collect(); + let ln_norm = weights.iter().sum::().ln(); + let normed_weights = weights.iter().map(|w| w.ln() - ln_norm).collect(); Ok(Categorical::new_unchecked(normed_weights)) } @@ -148,7 +146,7 @@ impl Categorical { } })?; - let sum = logsumexp(&ln_weights).abs(); + let sum = logsumexp(ln_weights.iter().map(|&x| x)).abs(); if sum < 10E-12 { Ok(Categorical { ln_weights }) } else { diff --git a/src/dist/cauchy.rs b/src/dist/cauchy.rs index d0ad0e8..ebbcedb 100644 --- a/src/dist/cauchy.rs +++ b/src/dist/cauchy.rs @@ -220,7 +220,7 @@ macro_rules! impl_traits { ln_scale, ); // TODO: make a logaddexp method for two floats - -logsumexp(&[ln_scale, term]) - LN_PI + -logsumexp([ln_scale, term].into_iter()) - LN_PI } } diff --git a/src/dist/mixture.rs b/src/dist/mixture.rs index 5601e0a..18bf895 100644 --- a/src/dist/mixture.rs +++ b/src/dist/mixture.rs @@ -388,14 +388,13 @@ where Fx: Rv, { fn ln_f(&self, x: &X) -> f64 { - let lfs: Vec = self + let lfs = self .ln_weights() .iter() .zip(self.components.iter()) - .map(|(&w, cpnt)| w + cpnt.ln_f(x)) - .collect(); + .map(|(&w, cpnt)| w + cpnt.ln_f(x)); - logsumexp(&lfs) + logsumexp(lfs) } fn f(&self, x: &X) -> f64 { diff --git a/src/misc/func.rs b/src/misc/func.rs index 2ce1de5..beeb9ba 100644 --- a/src/misc/func.rs +++ b/src/misc/func.rs @@ -93,25 +93,18 @@ pub fn ln_gammafn(x: f64) -> f64 { /// Safely compute `log(sum(exp(xs))` /// Streaming `logexp` implementation as described in [Sebastian Nowozin's blog](https://www.nowozin.net/sebastian/blog/streaming-log-sum-exp-computation.html) -pub fn logsumexp(xs: &[f64]) -> f64 { - if xs.is_empty() { - panic!("Empty container"); - } else if xs.len() == 1 { - xs[0] - } else { - let (alpha, r) = - xs.iter().fold((f64::NEG_INFINITY, 0.0), |(alpha, r), &x| { - if x == f64::NEG_INFINITY { - (alpha, r) - } else if x <= alpha { - (alpha, r + (x - alpha).exp()) - } else { - (x, r.mul_add((alpha - x).exp(), 1.0)) - } - }); +pub fn logsumexp(xs: impl Iterator) -> f64 { + let (alpha, r) = xs.fold((f64::NEG_INFINITY, 0.0), |(alpha, r), x| { + if x == f64::NEG_INFINITY { + (alpha, r) + } else if x <= alpha { + (alpha, r + (x - alpha).exp()) + } else { + (x, r.mul_add((alpha - x).exp(), 1.0)) + } + }); - r.ln() + alpha - } + r.ln() + alpha } /// Cumulative sum of `xs` @@ -270,7 +263,11 @@ pub fn ln_pflips( normed: bool, rng: &mut R, ) -> Vec { - let z = if normed { 0.0 } else { logsumexp(ln_weights) }; + let z = if normed { + 0.0 + } else { + logsumexp(ln_weights.iter().map(|&x| x)) + }; // doing this instead of calling pflips shaves about 30% off the runtime. let cws: Vec = ln_weights @@ -828,52 +825,50 @@ mod tests { assert_eq!(argmax(&xs), vec![4, 6]); } - #[test] - fn logsumexp_on_vector_of_zeros() { - let xs: Vec = vec![0.0; 5]; - // should be about log(5) - assert::close(logsumexp(&xs), 1.609_437_912_434_100_3, TOL); - } + use proptest::prelude::*; - #[test] - fn logsumexp_on_random_values() { - let xs: Vec = vec![ - 0.304_153_86, - -0.070_722_96, - -1.042_870_19, - 0.278_554_07, - -0.818_967_65, - ]; - assert::close(logsumexp(&xs), 1.482_000_789_426_305_9, TOL); - } + proptest! { + #[test] + fn proptest_logsumexp(xs in prop::collection::vec(-1e10f64..1e10, 0..100)) { + let result = logsumexp(xs.iter().cloned()); - #[test] - fn logsumexp_returns_only_value_on_one_element_container() { - let xs: Vec = vec![0.304_153_86]; - assert::close(logsumexp(&xs), 0.304_153_86, TOL); - } + if xs.is_empty() { + prop_assert!(result.is_nan()); + } else { + // Naive implementation for comparison + let max_x = xs.iter().cloned().max_by(|a, b| a.partial_cmp(b).unwrap()).unwrap(); + let sum_exp = xs.iter().map(|&x| (x - max_x).exp()).sum::(); + let expected = max_x + sum_exp.ln(); - #[test] - #[should_panic] - fn logsumexp_should_panic_on_empty() { - let xs: Vec = Vec::new(); - logsumexp(&xs); - } + // Check that the results are close + prop_assert!((result - expected).abs() < 1e-10); - #[test] - fn logsumexp_leading_neginf() { - let inf = f64::INFINITY; - let weights = vec![ - -inf, - -210.148_738_791_973_16, - -818.104_304_460_164_3, - -1_269.048_018_522_644_5, - -2_916.862_476_271_387, - -inf, - ]; - - let lse = logsumexp(&weights); - assert::close(lse, -210.148_738_791_973_16, TOL); + // Check that the result is greater than or equal to the maximum input + prop_assert!(result >= *xs.iter().max_by(|a, b| a.partial_cmp(b).unwrap()).unwrap()); + + // Check that exp(result) is greater than or equal to the sum of exp(x) for all x + let sum_exp_inputs: f64 = xs.iter().map(|&x| x.exp()).sum(); + prop_assert!(result.exp() >= sum_exp_inputs); + } + } + + #[test] + fn proptest_logsumexp_with_neg_infinity( + xs in prop::collection::vec(-1e10f64..1e10, 0..99), + neg_inf_count in 0..10usize + ) { + let mut extended_xs = xs.clone(); + extended_xs.extend(std::iter::repeat(f64::NEG_INFINITY).take(neg_inf_count)); + + let result = logsumexp(extended_xs.iter().cloned()); + + if extended_xs.iter().all(|&x| x == f64::NEG_INFINITY) { + prop_assert!(result == f64::NEG_INFINITY); + } else { + let expected = logsumexp(xs.iter().cloned()); + prop_assert!((result - expected).abs() < 1e-10); + } + } } #[test] From f2890b42a524bd5a0e00ea9a35d0a295973afbd9 Mon Sep 17 00:00:00 2001 From: Chad Scherrer Date: Wed, 31 Jul 2024 19:23:02 -0700 Subject: [PATCH 2/9] more logsumexp work --- src/dist/categorical.rs | 18 ++-- src/dist/cauchy.rs | 6 +- src/dist/mixture.rs | 10 +-- src/dist/normal_gamma/gaussian_prior.rs | 35 ++++---- .../normal_inv_chi_squared/gaussian_prior.rs | 36 ++++---- src/dist/normal_inv_gamma/gaussian_prior.rs | 38 ++++----- .../stick_breaking_process/stick_breaking.rs | 4 +- src/misc/func.rs | 83 ++++++++++++++----- tests/mi.rs | 18 ++-- 9 files changed, 138 insertions(+), 110 deletions(-) diff --git a/src/dist/categorical.rs b/src/dist/categorical.rs index 66d78d3..3c7a130 100644 --- a/src/dist/categorical.rs +++ b/src/dist/categorical.rs @@ -4,7 +4,7 @@ use serde::{Deserialize, Serialize}; use crate::data::{CategoricalDatum, CategoricalSuffStat}; use crate::impl_display; -use crate::misc::{argmax, ln_pflips, logsumexp, vec_to_string}; +use crate::misc::{argmax, ln_pflips, vec_to_string, LogSumExp}; use crate::traits::*; use rand::Rng; use std::fmt; @@ -146,8 +146,8 @@ impl Categorical { } })?; - let sum = logsumexp(ln_weights.iter().map(|&x| x)).abs(); - if sum < 10E-12 { + let sum = ln_weights.iter().logsumexp(); + if sum.abs() < 1E-12 { Ok(Categorical { ln_weights }) } else { Err(CategoricalError::WeightsDoNotSumToOne { ln: true, sum }) @@ -339,7 +339,11 @@ mod tests { // weights the def do not sum to 1 let weights: Vec = vec![2.0, 1.0, 2.0, 3.0, 1.0]; let cat = Categorical::new(&weights).unwrap(); - assert::close(logsumexp(&cat.ln_weights), 0.0, TOL); + assert::close( + (cat.ln_weights.iter().map(|&ln_w| ln_w)).logsumexp(), + 0.0, + TOL, + ); } #[test] @@ -350,7 +354,11 @@ mod tests { cat.ln_weights .iter() .for_each(|&ln_w| assert::close(ln_w, ln_weight, TOL)); - assert::close(logsumexp(&cat.ln_weights), 0.0, TOL); + assert::close( + (cat.ln_weights.iter().map(|&ln_w| ln_w)).logsumexp(), + 0.0, + TOL, + ); } #[test] diff --git a/src/dist/cauchy.rs b/src/dist/cauchy.rs index ebbcedb..b324053 100644 --- a/src/dist/cauchy.rs +++ b/src/dist/cauchy.rs @@ -4,7 +4,6 @@ use serde::{Deserialize, Serialize}; use crate::consts::LN_PI; use crate::impl_display; -use crate::misc::logsumexp; use crate::traits::*; use rand::Rng; use rand_distr::Cauchy as RCauchy; @@ -209,7 +208,7 @@ impl From<&Cauchy> for String { } impl_display!(Cauchy); - +use crate::misc::logaddexp; macro_rules! impl_traits { ($kind:ty) => { impl HasDensity<$kind> for Cauchy { @@ -219,8 +218,7 @@ macro_rules! impl_traits { ((f64::from(*x) - self.loc).abs().ln() - ln_scale), ln_scale, ); - // TODO: make a logaddexp method for two floats - -logsumexp([ln_scale, term].into_iter()) - LN_PI + -logaddexp(ln_scale, term) - LN_PI } } diff --git a/src/dist/mixture.rs b/src/dist/mixture.rs index 18bf895..f9bfc23 100644 --- a/src/dist/mixture.rs +++ b/src/dist/mixture.rs @@ -6,7 +6,7 @@ use serde::ser::{SerializeStruct, Serializer}; use serde::{Deserialize, Serialize}; use crate::dist::{Categorical, Gaussian, Poisson}; -use crate::misc::{logsumexp, pflips}; +use crate::misc::{pflips, LogSumExp}; use crate::traits::*; use rand::Rng; use std::fmt; @@ -388,13 +388,11 @@ where Fx: Rv, { fn ln_f(&self, x: &X) -> f64 { - let lfs = self - .ln_weights() + self.ln_weights() .iter() .zip(self.components.iter()) - .map(|(&w, cpnt)| w + cpnt.ln_f(x)); - - logsumexp(lfs) + .map(|(&w, cpnt)| w + cpnt.ln_f(x)) + .logsumexp() } fn f(&self, x: &X) -> f64 { diff --git a/src/dist/normal_gamma/gaussian_prior.rs b/src/dist/normal_gamma/gaussian_prior.rs index 87222c3..0e407de 100644 --- a/src/dist/normal_gamma/gaussian_prior.rs +++ b/src/dist/normal_gamma/gaussian_prior.rs @@ -184,7 +184,7 @@ mod tests { #[test] fn ln_m_vs_monte_carlo() { - use crate::misc::logsumexp; + use crate::misc::LogSumExp; let n_samples = 8_000_000; let xs = vec![1.0, 2.0, 3.0, 4.0, 5.0, 6.0]; @@ -194,14 +194,13 @@ mod tests { let ln_m = ng.ln_m(&DataOrSuffStat::::from(&xs)); let mc_est = { - let ln_fs: Vec = ng - .sample_stream(&mut rand::thread_rng()) + ng.sample_stream(&mut rand::thread_rng()) .take(n_samples) .map(|gauss: Gaussian| { xs.iter().map(|x| gauss.ln_f(x)).sum::() }) - .collect(); - logsumexp(&ln_fs) - (n_samples as f64).ln() + .logsumexp() + - (n_samples as f64).ln() }; // high error tolerance. MC estimation is not the most accurate... assert::close(ln_m, mc_est, 1e-2); @@ -209,7 +208,7 @@ mod tests { #[test] fn ln_m_vs_importance() { - use crate::misc::logsumexp; + use crate::misc::LogSumExp; let n_samples = 2_000_000; let xs = vec![1.0, 2.0, 3.0, 4.0, 5.0, 6.0]; @@ -223,19 +222,17 @@ mod tests { let mut rng = rand::thread_rng(); // let pr_p = Gamma::new(1.6, 2.2).unwrap(); // let pr_m = Gaussian::new(1.0, 2.0).unwrap(); - let ln_fs: Vec = (0..n_samples) - .map(|_| { - // let mu: f64 = pr_m.draw(&mut rng); - // let prec: f64 = pr_p.draw(&mut rng); - // let gauss = Gaussian::new(mu, prec.sqrt().recip()).unwrap(); - let gauss: Gaussian = post.draw(&mut rng); - let ln_f = xs.iter().map(|x| gauss.ln_f(x)).sum::(); - - // ln_f + ng.ln_f(&gauss) - pr_m.ln_f(&mu) - pr_p.ln_f(&prec) - ln_f + ng.ln_f(&gauss) - post.ln_f(&gauss) - }) - .collect(); - logsumexp(&ln_fs) - (n_samples as f64).ln() + let ln_fs = (0..n_samples).map(|_| { + // let mu: f64 = pr_m.draw(&mut rng); + // let prec: f64 = pr_p.draw(&mut rng); + // let gauss = Gaussian::new(mu, prec.sqrt().recip()).unwrap(); + let gauss: Gaussian = post.draw(&mut rng); + let ln_f = xs.iter().map(|x| gauss.ln_f(x)).sum::(); + + // ln_f + ng.ln_f(&gauss) - pr_m.ln_f(&mu) - pr_p.ln_f(&prec) + ln_f + ng.ln_f(&gauss) - post.ln_f(&gauss) + }); + ln_fs.logsumexp() - (n_samples as f64).ln() }; // high error tolerance. MC estimation is not the most accurate... assert::close(ln_m, mc_est, 1e-2); diff --git a/src/dist/normal_inv_chi_squared/gaussian_prior.rs b/src/dist/normal_inv_chi_squared/gaussian_prior.rs index 54f97c1..c340994 100644 --- a/src/dist/normal_inv_chi_squared/gaussian_prior.rs +++ b/src/dist/normal_inv_chi_squared/gaussian_prior.rs @@ -222,7 +222,7 @@ mod test { #[test] fn ln_m_single_datum_vs_monte_carlo() { - use crate::misc::logsumexp; + use crate::misc::LogSumExp; let n_samples = 1_000_000; let x: f64 = -0.3; @@ -233,12 +233,11 @@ mod test { let ln_m = nix.ln_m(&DataOrSuffStat::::from(&xs)); let mc_est = { - let ln_fs: Vec = nix - .sample_stream(&mut rand::thread_rng()) + nix.sample_stream(&mut rand::thread_rng()) .take(n_samples) .map(|gauss: Gaussian| gauss.ln_f(&x)) - .collect(); - logsumexp(&ln_fs) - (n_samples as f64).ln() + .logsumexp() + - (n_samples as f64).ln() }; // high error tolerance. MC estimation is not the most accurate... assert::close(ln_m, mc_est, 1e-2); @@ -246,7 +245,7 @@ mod test { #[test] fn ln_m_vs_monte_carlo() { - use crate::misc::logsumexp; + use crate::misc::LogSumExp; let n_samples = 1_000_000; let xs = vec![1.0, 2.0, 3.0, 4.0, 5.0, 6.0]; @@ -256,14 +255,13 @@ mod test { let ln_m = nix.ln_m(&DataOrSuffStat::::from(&xs)); let mc_est = { - let ln_fs: Vec = nix - .sample_stream(&mut rand::thread_rng()) + nix.sample_stream(&mut rand::thread_rng()) .take(n_samples) .map(|gauss: Gaussian| { xs.iter().map(|x| gauss.ln_f(x)).sum::() }) - .collect(); - logsumexp(&ln_fs) - (n_samples as f64).ln() + .logsumexp() + - (n_samples as f64).ln() }; // high error tolerance. MC estimation is not the most accurate... assert::close(ln_m, mc_est, 1e-2); @@ -271,7 +269,7 @@ mod test { #[test] fn ln_pp_vs_monte_carlo() { - use crate::misc::logsumexp; + use crate::misc::LogSumExp; let n_samples = 1_000_000; let xs = vec![1.0, 2.0, 3.0, 4.0, 5.0, 6.0]; @@ -283,12 +281,11 @@ mod test { let ln_pp = nix.ln_pp(&y, &DataOrSuffStat::::from(&xs)); let mc_est = { - let ln_fs: Vec = post - .sample_stream(&mut rand::thread_rng()) + post.sample_stream(&mut rand::thread_rng()) .take(n_samples) .map(|gauss: Gaussian| gauss.ln_f(&y)) - .collect(); - logsumexp(&ln_fs) - (n_samples as f64).ln() + .logsumexp() + - (n_samples as f64).ln() }; // high error tolerance. MC estimation is not the most accurate... assert::close(ln_pp, mc_est, 1e-2); @@ -296,7 +293,7 @@ mod test { #[test] fn ln_pp_single_vs_monte_carlo() { - use crate::misc::logsumexp; + use crate::misc::LogSumExp; let n_samples = 1_000_000; let x: f64 = -0.3; @@ -307,12 +304,11 @@ mod test { nix.ln_pp(&x, &DataOrSuffStat::::from(&vec![])); let mc_est = { - let ln_fs: Vec = nix - .sample_stream(&mut rand::thread_rng()) + nix.sample_stream(&mut rand::thread_rng()) .take(n_samples) .map(|gauss: Gaussian| gauss.ln_f(&x)) - .collect(); - logsumexp(&ln_fs) - (n_samples as f64).ln() + .logsumexp() + - (n_samples as f64).ln() }; // high error tolerance. MC estimation is not the most accurate... assert::close(ln_pp, mc_est, 1e-2); diff --git a/src/dist/normal_inv_gamma/gaussian_prior.rs b/src/dist/normal_inv_gamma/gaussian_prior.rs index 431d348..b4498bf 100644 --- a/src/dist/normal_inv_gamma/gaussian_prior.rs +++ b/src/dist/normal_inv_gamma/gaussian_prior.rs @@ -220,7 +220,7 @@ mod test { #[test] fn ln_m_vs_monte_carlo() { - use crate::misc::logsumexp; + use crate::misc::LogSumExp; let n_samples = 1_000_000; let xs = vec![1.0, 2.0, 3.0, 4.0, 5.0, 6.0]; @@ -231,14 +231,13 @@ mod test { // let ln_m = alternate_ln_marginal(&xs, m, v, a, b); let mc_est = { - let ln_fs: Vec = nig - .sample_stream(&mut rand::thread_rng()) + nig.sample_stream(&mut rand::thread_rng()) .take(n_samples) .map(|gauss: Gaussian| { xs.iter().map(|x| gauss.ln_f(x)).sum::() }) - .collect(); - logsumexp(&ln_fs) - (n_samples as f64).ln() + .logsumexp() + - (n_samples as f64).ln() }; // high error tolerance. MC estimation is not the most accurate... assert::close(ln_m, mc_est, 1e-2); @@ -247,7 +246,7 @@ mod test { #[test] fn ln_m_vs_importance() { use crate::dist::Gamma; - use crate::misc::logsumexp; + use crate::misc::LogSumExp; let n_samples = 1_000_000; let xs = vec![1.0, 2.0, 3.0, 4.0, 5.0, 6.0]; @@ -260,16 +259,14 @@ mod test { let mut rng = rand::thread_rng(); let pr_m = Gaussian::new(1.0, 8.0).unwrap(); let pr_s = Gamma::new(2.0, 0.4).unwrap(); - let ln_fs: Vec = (0..n_samples) - .map(|_| { - let mu: f64 = pr_m.draw(&mut rng); - let var: f64 = pr_s.draw(&mut rng); - let gauss = Gaussian::new(mu, var.sqrt()).unwrap(); - let ln_f = xs.iter().map(|x| gauss.ln_f(x)).sum::(); - ln_f + nig.ln_f(&gauss) - pr_m.ln_f(&mu) - pr_s.ln_f(&var) - }) - .collect(); - logsumexp(&ln_fs) - (n_samples as f64).ln() + let ln_fs = (0..n_samples).map(|_| { + let mu: f64 = pr_m.draw(&mut rng); + let var: f64 = pr_s.draw(&mut rng); + let gauss = Gaussian::new(mu, var.sqrt()).unwrap(); + let ln_f = xs.iter().map(|x| gauss.ln_f(x)).sum::(); + ln_f + nig.ln_f(&gauss) - pr_m.ln_f(&mu) - pr_s.ln_f(&var) + }); + ln_fs.logsumexp() - (n_samples as f64).ln() }; // high error tolerance. MC estimation is not the most accurate... assert::close(ln_m, mc_est, 1e-2); @@ -277,7 +274,7 @@ mod test { #[test] fn ln_pp_vs_monte_carlo() { - use crate::misc::logsumexp; + use crate::misc::LogSumExp; let n_samples = 1_000_000; let xs = vec![1.0, 2.0, 3.0, 4.0, 5.0, 6.0]; @@ -290,12 +287,11 @@ mod test { // let ln_m = alternate_ln_marginal(&xs, m, v, a, b); let mc_est = { - let ln_fs: Vec = post - .sample_stream(&mut rand::thread_rng()) + post.sample_stream(&mut rand::thread_rng()) .take(n_samples) .map(|gauss: Gaussian| gauss.ln_f(&y)) - .collect(); - logsumexp(&ln_fs) - (n_samples as f64).ln() + .logsumexp() + - (n_samples as f64).ln() }; // high error tolerance. MC estimation is not the most accurate... assert::close(ln_pp, mc_est, 1e-2); diff --git a/src/experimental/stick_breaking_process/stick_breaking.rs b/src/experimental/stick_breaking_process/stick_breaking.rs index 4f47b8c..b6cf691 100644 --- a/src/experimental/stick_breaking_process/stick_breaking.rs +++ b/src/experimental/stick_breaking_process/stick_breaking.rs @@ -391,8 +391,6 @@ mod tests { #[test] fn sb_ln_m_vs_monte_carlo() { - use crate::misc::logsumexp; - let n_samples = 1_000_000; let xs: Vec = vec![1, 2, 3]; @@ -408,7 +406,7 @@ mod tests { xs.iter().map(|x| sbd.ln_f(x)).sum::() }) .collect(); - logsumexp(&ln_fs) - (n_samples as f64).ln() + ln_fs.logsumexp() - (n_samples as f64).ln() }; // high error tolerance. MC estimation is not the most accurate... assert::close(ln_m, mc_est, 1e-2); diff --git a/src/misc/func.rs b/src/misc/func.rs index beeb9ba..484ce31 100644 --- a/src/misc/func.rs +++ b/src/misc/func.rs @@ -91,20 +91,51 @@ pub fn ln_gammafn(x: f64) -> f64 { Gamma::ln_gamma(x).0 } -/// Safely compute `log(sum(exp(xs))` +pub fn log1pexp(x: f64) -> f64 { + if x <= -37.0 { + f64::exp(x) + } else if x <= 18.0 { + f64::ln_1p(f64::exp(x)) + } else if x <= 33.3 { + x + f64::exp(-x) + } else { + x + } +} + +pub fn logaddexp(x: f64, y: f64) -> f64 { + if x > y { + x + log1pexp(y - x) + } else { + y + log1pexp(x - y) + } +} + /// Streaming `logexp` implementation as described in [Sebastian Nowozin's blog](https://www.nowozin.net/sebastian/blog/streaming-log-sum-exp-computation.html) -pub fn logsumexp(xs: impl Iterator) -> f64 { - let (alpha, r) = xs.fold((f64::NEG_INFINITY, 0.0), |(alpha, r), x| { - if x == f64::NEG_INFINITY { - (alpha, r) - } else if x <= alpha { - (alpha, r + (x - alpha).exp()) - } else { - (x, r.mul_add((alpha - x).exp(), 1.0)) - } - }); +pub trait LogSumExp { + fn logsumexp(self) -> f64; +} + +use std::borrow::Borrow; - r.ln() + alpha +impl LogSumExp for I +where + I: Iterator, + I::Item: std::borrow::Borrow, +{ + fn logsumexp(self) -> f64 { + let (max, sum) = + self.fold((f64::NEG_INFINITY, 0.0), |(max, sum), x| { + let x = *x.borrow(); + if x > max { + (x, sum * (max - x).exp()) + } else { + (max, sum + (x - max).exp()) + } + }); + + max + sum.ln() + } } /// Cumulative sum of `xs` @@ -266,7 +297,7 @@ pub fn ln_pflips( let z = if normed { 0.0 } else { - logsumexp(ln_weights.iter().map(|&x| x)) + ln_weights.iter().copied().logsumexp() }; // doing this instead of calling pflips shaves about 30% off the runtime. @@ -651,7 +682,7 @@ const LN_FACT: [f64; 255] = [ 921.837_328_707_804_9, 927.193_914_982_476_7, 932.555_207_148_186_2, - 937.921_183_163_208_1, + 937.921_821_191_335_7, 943.291_821_191_335_7, 948.667_099_599_019_8, 954.046_996_952_560_4, @@ -746,7 +777,16 @@ pub fn log_product(data: impl Iterator) -> f64 { #[cfg(test)] mod tests { use super::*; + use proptest::prelude::*; + proptest! { + #[test] + fn test_log1pexp_close_to_ln_1p_exp(x in -100.0..100.0_f64) { + let expected = (1.0 + x.exp()).ln(); + let actual = log1pexp(x); + prop_assert!((expected - actual).abs() < 1e-10); + } + } #[test] fn test_log_product_empty() { let empty: Vec = vec![]; @@ -794,7 +834,6 @@ mod tests { assert_eq!(log_product(with_zero.into_iter()), f64::NEG_INFINITY); } - use super::*; use crate::prelude::ChiSquared; use crate::traits::Cdf; use rand::thread_rng; @@ -825,12 +864,10 @@ mod tests { assert_eq!(argmax(&xs), vec![4, 6]); } - use proptest::prelude::*; - proptest! { #[test] - fn proptest_logsumexp(xs in prop::collection::vec(-1e10f64..1e10, 0..100)) { - let result = logsumexp(xs.iter().cloned()); + fn proptest_logsumexp(xs in prop::collection::vec(-1e10_f64..1e10_f64, 0..100)) { + let result = xs.iter().logsumexp(); if xs.is_empty() { prop_assert!(result.is_nan()); @@ -854,18 +891,18 @@ mod tests { #[test] fn proptest_logsumexp_with_neg_infinity( - xs in prop::collection::vec(-1e10f64..1e10, 0..99), - neg_inf_count in 0..10usize + xs in prop::collection::vec(-1e10_f64..1e10_f64, 0..99), + neg_inf_count in 0..10_usize ) { let mut extended_xs = xs.clone(); extended_xs.extend(std::iter::repeat(f64::NEG_INFINITY).take(neg_inf_count)); - let result = logsumexp(extended_xs.iter().cloned()); + let result = extended_xs.iter().logsumexp(); if extended_xs.iter().all(|&x| x == f64::NEG_INFINITY) { prop_assert!(result == f64::NEG_INFINITY); } else { - let expected = logsumexp(xs.iter().cloned()); + let expected = xs.iter().logsumexp(); prop_assert!((result - expected).abs() < 1e-10); } } diff --git a/tests/mi.rs b/tests/mi.rs index 985e935..514aa78 100644 --- a/tests/mi.rs +++ b/tests/mi.rs @@ -4,6 +4,8 @@ use rv::traits::*; #[test] fn bivariate_mixture_mi() { + use rv::misc::LogSumExp; + let n_samples = 100_000; let n_f = n_samples as f64; @@ -56,15 +58,13 @@ fn bivariate_mixture_mi() { let logpy = my.ln_f(&y); let logpxy = { - let ps: Vec = (0..k) - .map(|ix| { - let px = mx.components()[ix].ln_f(&x); - let py = my.components()[ix].ln_f(&y); - px + py - lnk - }) - .collect(); - - rv::misc::logsumexp(&ps) + let ps = (0..k).map(|ix| { + let px = mx.components()[ix].ln_f(&x); + let py = my.components()[ix].ln_f(&y); + px + py - lnk + }); + + ps.logsumexp() }; (mi + logpxy - logpx - logpy, hxy - logpxy) From cedb4b2decb61d6c140d7689a6f93a78a588a694 Mon Sep 17 00:00:00 2001 From: Chad Scherrer Date: Wed, 31 Jul 2024 19:25:07 -0700 Subject: [PATCH 3/9] moar --- proptest-regressions/misc/func.txt | 7 +++++++ .../stick_breaking_process/stick_breaking.rs | 9 +++++---- src/misc/mod.rs | 2 +- 3 files changed, 13 insertions(+), 5 deletions(-) create mode 100644 proptest-regressions/misc/func.txt diff --git a/proptest-regressions/misc/func.txt b/proptest-regressions/misc/func.txt new file mode 100644 index 0000000..fd26ebb --- /dev/null +++ b/proptest-regressions/misc/func.txt @@ -0,0 +1,7 @@ +# Seeds for failure cases proptest has generated in the past. It is +# automatically read and these particular cases re-run before any +# novel cases are generated. +# +# It is recommended to check this file in to source control so that +# everyone who runs the test benefits from these saved cases. +cc b2c0853abfa53c6717adb519ac8b2871674f6977f57441e1f48d71c1c625c4f0 # shrinks to xs = [] diff --git a/src/experimental/stick_breaking_process/stick_breaking.rs b/src/experimental/stick_breaking_process/stick_breaking.rs index b6cf691..52fa680 100644 --- a/src/experimental/stick_breaking_process/stick_breaking.rs +++ b/src/experimental/stick_breaking_process/stick_breaking.rs @@ -391,6 +391,8 @@ mod tests { #[test] fn sb_ln_m_vs_monte_carlo() { + use crate::misc::func::LogSumExp; + let n_samples = 1_000_000; let xs: Vec = vec![1, 2, 3]; @@ -399,14 +401,13 @@ mod tests { let ln_m = sb.ln_m(&obs); let mc_est = { - let ln_fs: Vec = sb - .sample_stream(&mut rand::thread_rng()) + sb.sample_stream(&mut rand::thread_rng()) .take(n_samples) .map(|sbd: StickBreakingDiscrete| { xs.iter().map(|x| sbd.ln_f(x)).sum::() }) - .collect(); - ln_fs.logsumexp() - (n_samples as f64).ln() + .logsumexp() + - (n_samples as f64).ln() }; // high error tolerance. MC estimation is not the most accurate... assert::close(ln_m, mc_est, 1e-2); diff --git a/src/misc/mod.rs b/src/misc/mod.rs index 14d29b4..0ac7790 100644 --- a/src/misc/mod.rs +++ b/src/misc/mod.rs @@ -2,7 +2,7 @@ pub mod bessel; mod convergent_seq; pub(crate) mod entropy; -mod func; +pub mod func; mod ks; mod legendre; #[cfg(feature = "arraydist")] From 8d548afebb255eae6c8357c723b73f73ed128d4c Mon Sep 17 00:00:00 2001 From: Chad Scherrer Date: Wed, 31 Jul 2024 19:47:24 -0700 Subject: [PATCH 4/9] oops --- src/misc/func.rs | 7 ++++--- 1 file changed, 4 insertions(+), 3 deletions(-) diff --git a/src/misc/func.rs b/src/misc/func.rs index 484ce31..6e10b1e 100644 --- a/src/misc/func.rs +++ b/src/misc/func.rs @@ -654,7 +654,7 @@ const LN_FACT: [f64; 255] = [ 773.860_102_952_558_5, 779.075_038_710_167_4, 784.295_394_535_245_7, - 789.521_141_208_959, + 789.921_183_163_208_1, 794.752_249_825_813_5, 799.988_691_788_643_5, 805.230_438_803_703_1, @@ -868,9 +868,10 @@ mod tests { #[test] fn proptest_logsumexp(xs in prop::collection::vec(-1e10_f64..1e10_f64, 0..100)) { let result = xs.iter().logsumexp(); - + println!("xs: {:?}", xs); + println!("result: {}", result); if xs.is_empty() { - prop_assert!(result.is_nan()); + prop_assert!(result == f64::NEG_INFINITY); } else { // Naive implementation for comparison let max_x = xs.iter().cloned().max_by(|a, b| a.partial_cmp(b).unwrap()).unwrap(); From a1cd33c6a08d76b0d78ae3d6344e089138adfe26 Mon Sep 17 00:00:00 2001 From: Chad Scherrer Date: Wed, 31 Jul 2024 19:57:29 -0700 Subject: [PATCH 5/9] fix logsumexp --- src/misc/func.rs | 31 ++++++------------------------- 1 file changed, 6 insertions(+), 25 deletions(-) diff --git a/src/misc/func.rs b/src/misc/func.rs index 6e10b1e..a5eb242 100644 --- a/src/misc/func.rs +++ b/src/misc/func.rs @@ -124,17 +124,17 @@ where I::Item: std::borrow::Borrow, { fn logsumexp(self) -> f64 { - let (max, sum) = - self.fold((f64::NEG_INFINITY, 0.0), |(max, sum), x| { + let (alpha, r) = + self.fold((f64::NEG_INFINITY, 0.0), |(alpha, r), x| { let x = *x.borrow(); - if x > max { - (x, sum * (max - x).exp()) + if x <= alpha { + (alpha, r + (x - alpha).exp()) } else { - (max, sum + (x - max).exp()) + (x, (alpha - x).exp().mul_add(r, 1.0)) } }); - max + sum.ln() + alpha + r.ln() } } @@ -868,8 +868,6 @@ mod tests { #[test] fn proptest_logsumexp(xs in prop::collection::vec(-1e10_f64..1e10_f64, 0..100)) { let result = xs.iter().logsumexp(); - println!("xs: {:?}", xs); - println!("result: {}", result); if xs.is_empty() { prop_assert!(result == f64::NEG_INFINITY); } else { @@ -890,23 +888,6 @@ mod tests { } } - #[test] - fn proptest_logsumexp_with_neg_infinity( - xs in prop::collection::vec(-1e10_f64..1e10_f64, 0..99), - neg_inf_count in 0..10_usize - ) { - let mut extended_xs = xs.clone(); - extended_xs.extend(std::iter::repeat(f64::NEG_INFINITY).take(neg_inf_count)); - - let result = extended_xs.iter().logsumexp(); - - if extended_xs.iter().all(|&x| x == f64::NEG_INFINITY) { - prop_assert!(result == f64::NEG_INFINITY); - } else { - let expected = xs.iter().logsumexp(); - prop_assert!((result - expected).abs() < 1e-10); - } - } } #[test] From c35692bf0d25b6f4c0f95f6d31f1c8d219c326be Mon Sep 17 00:00:00 2001 From: Chad Scherrer Date: Wed, 31 Jul 2024 20:04:40 -0700 Subject: [PATCH 6/9] oops --- src/misc/func.rs | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/src/misc/func.rs b/src/misc/func.rs index a5eb242..8c0c5f1 100644 --- a/src/misc/func.rs +++ b/src/misc/func.rs @@ -654,7 +654,7 @@ const LN_FACT: [f64; 255] = [ 773.860_102_952_558_5, 779.075_038_710_167_4, 784.295_394_535_245_7, - 789.921_183_163_208_1, + 789.521_141_208_959, 794.752_249_825_813_5, 799.988_691_788_643_5, 805.230_438_803_703_1, @@ -682,7 +682,7 @@ const LN_FACT: [f64; 255] = [ 921.837_328_707_804_9, 927.193_914_982_476_7, 932.555_207_148_186_2, - 937.921_821_191_335_7, + 937.921_183_163_208_1, 943.291_821_191_335_7, 948.667_099_599_019_8, 954.046_996_952_560_4, From 7097d5a28a9a8478ffd23c6eff545c66845538db Mon Sep 17 00:00:00 2001 From: Chad Scherrer Date: Thu, 1 Aug 2024 06:12:30 -0700 Subject: [PATCH 7/9] bump version, add changelog --- CHANGELOG.md | 10 +++++ Cargo.lock | 120 +++++++++++++++++++++++++++++++++------------------ Cargo.toml | 2 +- 3 files changed, 90 insertions(+), 42 deletions(-) diff --git a/CHANGELOG.md b/CHANGELOG.md index 12d5fd9..630cd88 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -1,5 +1,15 @@ # Changelog +## [0.18.0] - 2024-06-24 + +### Added +- Add log1pexp and logaddexp +- Add LogSumExp trait with logsumexp method. This way we can make applying it a little more generic, similar to how sum works. +- Propagate these functions across crate + +### Removed +- Removed logsumexp function taking a slice argument + ## [0.17.0] - 2024-06-24 ### Added diff --git a/Cargo.lock b/Cargo.lock index 3a08b42..1aed388 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -11,7 +11,7 @@ dependencies = [ "cfg-if", "once_cell", "version_check", - "zerocopy", + "zerocopy 0.7.35", ] [[package]] @@ -37,9 +37,9 @@ checksum = "4b46cbb362ab8752921c97e041f5e366ee6297bd428a31275b9fcf1e380f7299" [[package]] name = "anstyle" -version = "1.0.7" +version = "1.0.8" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "038dfcf04a5feb68e9c60b21c9625a54c2c0616e79b72b0fd87075a056ae1d1b" +checksum = "1bec1de6f59aedf83baf9ff929c98f2ad654b97c9510f4e70cf6f661d49fd5b1" [[package]] name = "anyhow" @@ -129,9 +129,15 @@ checksum = "79296716171880943b8470b5f8d03aa55eb2e645a4874bdbb28adb49162e012c" [[package]] name = "bytemuck" -version = "1.16.1" +version = "1.16.3" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "102087e286b4677862ea56cf8fc58bb2cdfa8725c40ffb80fe3a008eb7f2fc83" + +[[package]] +name = "byteorder" +version = "1.5.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "b236fc92302c97ed75b38da1f4917b5cdda4984745740f153a5d3059e48d725e" +checksum = "1fd0f2584146f6f2ef48085050886acf353beff7305ebd1ae69500e27c67f64b" [[package]] name = "cast" @@ -174,18 +180,18 @@ dependencies = [ [[package]] name = "clap" -version = "4.5.9" +version = "4.5.13" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "64acc1846d54c1fe936a78dc189c34e28d3f5afc348403f28ecf53660b9b8462" +checksum = "0fbb260a053428790f3de475e304ff84cdbc4face759ea7a3e64c1edd938a7fc" dependencies = [ "clap_builder", ] [[package]] name = "clap_builder" -version = "4.5.9" +version = "4.5.13" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "6fb8393d67ba2e7bfaf28a23458e4e2b543cc73a99595511eb207fdb8aede942" +checksum = "64b17d7ea74e9f833c7dbf2cbe4fb12ff26783eda4782a8975b72f895c9b4d99" dependencies = [ "anstyle", "clap_lex", @@ -193,9 +199,9 @@ dependencies = [ [[package]] name = "clap_lex" -version = "0.7.1" +version = "0.7.2" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "4b82cf0babdbd58558212896d1a4272303a57bdb245c2bf1147185fb45640e70" +checksum = "1462739cb27611015575c0c11df5df7601141071f07518d56fcc1be504cbec97" [[package]] name = "criterion" @@ -352,9 +358,9 @@ checksum = "d231dfb89cfffdbc30e7fc41579ed6066ad03abda9e567ccafae602b97ec5024" [[package]] name = "indexmap" -version = "2.2.6" +version = "2.3.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "168fb715dda47215e360912c096649d23d58bf392ac62f73919e831745e40f26" +checksum = "de3fc2e30ba82dd1b3911c8de1ffc143c74a914a14e99514d7637e3099df5ea0" dependencies = [ "equivalent", "hashbrown", @@ -419,6 +425,12 @@ dependencies = [ "wasm-bindgen", ] +[[package]] +name = "lambert_w" +version = "0.3.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "416b1315eb2fad6ac298e11cd12b6bf3d829707a44260cb7a2c265f156be8c4c" + [[package]] name = "lazy_static" version = "1.5.0" @@ -451,18 +463,18 @@ checksum = "a7a70ba024b9dc04c27ea2f0c0548feb474ec5c54bba33a7f72f873a39d07b24" [[package]] name = "lru" -version = "0.12.3" +version = "0.12.4" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "d3262e75e648fce39813cb56ac41f3c3e3f65217ebf3844d818d1f9398cfb0dc" +checksum = "37ee39891760e7d94734f6f63fedc29a2e4a152f836120753a72503f09fcf904" dependencies = [ "hashbrown", ] [[package]] name = "matrixmultiply" -version = "0.3.8" +version = "0.3.9" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "7574c1cf36da4798ab73da5b215bbf444f50718207754cb522201d78d1cd0ff2" +checksum = "9380b911e3e96d10c1f415da0876389aaf1b56759054eeb0de7df940c456ba1a" dependencies = [ "autocfg", "num_cpus", @@ -502,7 +514,7 @@ checksum = "254a5372af8fc138e36684761d3c0cdb758a4410e938babcff1c860ce14ddbfc" dependencies = [ "proc-macro2", "quote", - "syn 2.0.70", + "syn 2.0.72", ] [[package]] @@ -616,11 +628,12 @@ checksum = "57c0d7b74b563b49d38dae00a0c37d4d6de9b432382b2892f0574ddcae73fd0a" [[package]] name = "peroxide" -version = "0.37.7" +version = "0.37.9" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "1a750c02bdc6548cf4ada978d670da80d274f491b514a3760efe88fb9ca1b5dc" +checksum = "0a89bd5e5373da993f5bf84c8b4441352c96b7260ff670012bb4c0c61022f6d7" dependencies = [ "anyhow", + "lambert_w", "matrixmultiply", "order-stat", "paste", @@ -677,9 +690,12 @@ dependencies = [ [[package]] name = "ppv-lite86" -version = "0.2.17" +version = "0.2.18" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "5b40af805b3121feab8a3c29f04d8ad262fa8e0561883e7653e024ae4479e6de" +checksum = "dee4364d9f3b902ef14fab8a1ddffb783a1cb6b4bba3bfc1fa3922732c7de97f" +dependencies = [ + "zerocopy 0.6.6", +] [[package]] name = "proc-macro2" @@ -872,7 +888,7 @@ dependencies = [ [[package]] name = "rv" -version = "0.17.1" +version = "0.18.0" dependencies = [ "approx", "argmin", @@ -938,16 +954,17 @@ checksum = "e0cd7e117be63d3c3678776753929474f3b04a43a080c744d6b0ae2a8c28e222" dependencies = [ "proc-macro2", "quote", - "syn 2.0.70", + "syn 2.0.72", ] [[package]] name = "serde_json" -version = "1.0.120" +version = "1.0.121" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "4e0d21c9a8cae1235ad58a00c11cb40d4b1e5c784f1ef2c537876ed6ffd8b7c5" +checksum = "4ab380d7d9f22ef3f21ad3e6c1ebe8e4fc7a2000ccba2e4d71fc96f15b2cb609" dependencies = [ "itoa", + "memchr", "ryu", "serde", ] @@ -1000,9 +1017,9 @@ dependencies = [ [[package]] name = "syn" -version = "2.0.70" +version = "2.0.72" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "2f0209b68b3613b093e0ec905354eccaedcfe83b8cb37cbdeae64026c3064c16" +checksum = "dc4b9b9bf2add8093d3f2c0204471e951b2285580335de42f9d2534f3ae7a8af" dependencies = [ "proc-macro2", "quote", @@ -1023,22 +1040,22 @@ dependencies = [ [[package]] name = "thiserror" -version = "1.0.61" +version = "1.0.63" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "c546c80d6be4bc6a00c0f01730c08df82eaa7a7a61f11d656526506112cc1709" +checksum = "c0342370b38b6a11b6cc11d6a805569958d54cfa061a29969c3b5ce2ea405724" dependencies = [ "thiserror-impl", ] [[package]] name = "thiserror-impl" -version = "1.0.61" +version = "1.0.63" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "46c3384250002a6d5af4d114f2845d37b57521033f30d5c3f46c4d70e1197533" +checksum = "a4558b58466b9ad7ca0f102865eccc95938dca1a74a856f2b57b6629050da261" dependencies = [ "proc-macro2", "quote", - "syn 2.0.70", + "syn 2.0.72", ] [[package]] @@ -1086,9 +1103,9 @@ checksum = "673aac59facbab8a9007c7f6108d11f63b603f7cabff99fabf650fea5c32b861" [[package]] name = "version_check" -version = "0.9.4" +version = "0.9.5" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "49874b5167b65d7193b8aba1567f5c7d93d001cafc34600cee003eda787e483f" +checksum = "0b928f33d975fc6ad9f86c8f283853ad26bdd5b10b7f1542aa2fa15e2289105a" [[package]] name = "wait-timeout" @@ -1136,7 +1153,7 @@ dependencies = [ "once_cell", "proc-macro2", "quote", - "syn 2.0.70", + "syn 2.0.72", "wasm-bindgen-shared", ] @@ -1158,7 +1175,7 @@ checksum = "e94f17b526d0a461a191c78ea52bbce64071ed5c04c9ffe424dcb38f74171bb7" dependencies = [ "proc-macro2", "quote", - "syn 2.0.70", + "syn 2.0.72", "wasm-bindgen-backend", "wasm-bindgen-shared", ] @@ -1181,9 +1198,9 @@ dependencies = [ [[package]] name = "wide" -version = "0.7.25" +version = "0.7.26" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "2caba658a80831539b30698ae9862a72db6697dfdd7151e46920f5f2755c3ce2" +checksum = "901e8597c777fa042e9e245bd56c0dc4418c5db3f845b6ff94fbac732c6a0692" dependencies = [ "bytemuck", "safe_arch", @@ -1271,13 +1288,34 @@ version = "0.52.6" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "589f6da84c646204747d1270a2a5661ea66ed1cced2631d546fdfb155959f9ec" +[[package]] +name = "zerocopy" +version = "0.6.6" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "854e949ac82d619ee9a14c66a1b674ac730422372ccb759ce0c39cabcf2bf8e6" +dependencies = [ + "byteorder", + "zerocopy-derive 0.6.6", +] + [[package]] name = "zerocopy" version = "0.7.35" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "1b9b4fd18abc82b8136838da5d50bae7bdea537c574d8dc1a34ed098d6c166f0" dependencies = [ - "zerocopy-derive", + "zerocopy-derive 0.7.35", +] + +[[package]] +name = "zerocopy-derive" +version = "0.6.6" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "125139de3f6b9d625c39e2efdd73d41bdac468ccd556556440e322be0e1bbd91" +dependencies = [ + "proc-macro2", + "quote", + "syn 2.0.72", ] [[package]] @@ -1288,5 +1326,5 @@ checksum = "fa4f8080344d4671fb4e831a13ad1e68092748387dfc4f55e356242fae12ce3e" dependencies = [ "proc-macro2", "quote", - "syn 2.0.70", + "syn 2.0.72", ] diff --git a/Cargo.toml b/Cargo.toml index a8ee59f..e1adc38 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -1,6 +1,6 @@ [package] name = "rv" -version = "0.17.1" +version = "0.18.0" authors = ["Baxter Eaves", "Michael Schmidt", "Chad Scherrer"] description = "Random variables" repository = "https://github.com/promised-ai/rv" From a904d2f94e2e3b24471977bbc42fcc72d85ad691 Mon Sep 17 00:00:00 2001 From: Chad Scherrer Date: Thu, 1 Aug 2024 08:06:12 -0700 Subject: [PATCH 8/9] release change link --- CHANGELOG.md | 1 + 1 file changed, 1 insertion(+) diff --git a/CHANGELOG.md b/CHANGELOG.md index 630cd88..6b2ce8c 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -215,6 +215,7 @@ - Remove dependency on `quadrature` crate in favor of hand-rolled adaptive Simpson's rule, which handles multimodal distributions better. +[0.18.0]: https://github.com/promise-ai/rv/compare/v0.17.0...v0.18.0 [0.17.0]: https://github.com/promise-ai/rv/compare/v0.16.5...v0.17.0 [0.16.5]: https://github.com/promise-ai/rv/compare/v0.16.4...v0.16.5 [0.16.4]: https://github.com/promise-ai/rv/compare/v0.16.3...v0.16.4 From 061fb336df25c7f255797de7c506dad5ee099a39 Mon Sep 17 00:00:00 2001 From: Chad Scherrer Date: Thu, 1 Aug 2024 08:46:28 -0700 Subject: [PATCH 9/9] minor update --- src/misc/func.rs | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/misc/func.rs b/src/misc/func.rs index 8c0c5f1..c7b73f0 100644 --- a/src/misc/func.rs +++ b/src/misc/func.rs @@ -297,7 +297,7 @@ pub fn ln_pflips( let z = if normed { 0.0 } else { - ln_weights.iter().copied().logsumexp() + ln_weights.iter().logsumexp() }; // doing this instead of calling pflips shaves about 30% off the runtime.