From 9fa09ea3fea50fa9da5a3dae6658086ed46cad5d Mon Sep 17 00:00:00 2001 From: Chad Scherrer Date: Mon, 15 Apr 2024 09:43:59 -0700 Subject: [PATCH] move sorted_uniforms to func.rs --- .../stick_breaking_process/sbd.rs | 95 +----------------- src/misc/func.rs | 98 +++++++++++++++++++ 2 files changed, 99 insertions(+), 94 deletions(-) diff --git a/src/experimental/stick_breaking_process/sbd.rs b/src/experimental/stick_breaking_process/sbd.rs index 8f6967c..619c12a 100644 --- a/src/experimental/stick_breaking_process/sbd.rs +++ b/src/experimental/stick_breaking_process/sbd.rs @@ -1,5 +1,6 @@ use super::StickSequence; use crate::dist::Mixture; +use crate::misc::sorted_uniforms; use crate::misc::ConvergentSequence; use crate::traits::*; use rand::seq::SliceRandom; @@ -198,54 +199,6 @@ impl Mode for StickBreakingDiscrete { } } -/// Generate a vector of sorted uniform random variables. -/// -/// # Arguments -/// -/// * `n` - The number of random variables to generate. -/// -/// * `rng` - A mutable reference to the random number generator. -/// -/// # Returns -/// -/// A vector of sorted uniform random variables. 
-/// -/// # Example -/// -/// ``` -/// use rand::thread_rng; -/// use rv::experimental::stick_breaking_process::sbd::sorted_uniforms; -/// -/// let mut rng = thread_rng(); -/// let n = 10000; -/// let xs = sorted_uniforms(n, &mut rng); -/// assert_eq!(xs.len(), n); -/// -/// // Result is sorted and in the unit interval -/// assert!(xs.first().map_or(false, |&first| first > 0.0)); -/// assert!(xs.last().map_or(false, |&last| last < 1.0)); -/// assert!(xs.windows(2).all(|w| w[0] <= w[1])); -/// -/// // Mean is approximately 1/2 -/// let mean = xs.iter().sum::<f64>() / n as f64; -/// assert!(mean > 0.49 && mean < 0.51); -/// -/// // Variance is approximately 1/12 -/// let var = xs.iter().map(|x| (x - 0.5).powi(2)).sum::<f64>() / n as f64; -/// assert!(var > 0.08 && var < 0.09); -/// ``` -pub fn sorted_uniforms<R: Rng>(n: usize, rng: &mut R) -> Vec<f64> { - let mut xs: Vec<_> = (0..n) - .map(|_| -rng.gen::<f64>().ln()) - .scan(0.0, |state, x| { - *state += x; - Some(*state) - }) - .collect(); - let max = *xs.last().unwrap() - rng.gen::<f64>().ln(); - (0..n).for_each(|i| xs[i] /= max); - xs -} /// Provides density and log-density functions for StickBreakingDiscrete. impl HasDensity for StickBreakingDiscrete { /// Computes the density of a given stick index. 
@@ -353,52 +306,6 @@ mod tests { use crate::prelude::*; use rand::thread_rng; - #[test] - fn test_sorted_uniforms() { - let mut rng = thread_rng(); - let n = 1000; - let xs = sorted_uniforms(n, &mut rng); - assert_eq!(xs.len(), n); - - // Result is sorted and in the unit interval - assert!(&0.0 < xs.first().unwrap()); - assert!(xs.last().unwrap() < &1.0); - assert!(xs.windows(2).all(|w| w[0] <= w[1])); - - // t will aggregate our chi-squared test statistic - let mut t = 0.0; - - { - // We'll build a histogram and count the bin populations, aggregating - // the chi-squared statistic as we go - let mut next_bin = 0.01; - let mut bin_pop = 0; - - for x in xs.iter() { - bin_pop += 1; - if *x > next_bin { - let obs = bin_pop as f64; - let exp = n as f64 / 100.0; - t += (obs - exp).powi(2) / exp; - bin_pop = 0; - next_bin += 0.01; - } - } - - // The last bin - let obs = bin_pop as f64; - let exp = n as f64 / 100.0; - t += (obs - exp).powi(2) / exp; - } - - let alpha = 0.001; - - // dof = number of bins minus one - let chi2 = ChiSquared::new(99.0).unwrap(); - let p = chi2.sf(&t); - assert!(p > alpha); - } - #[test] fn test_multi_invccdf_sorted() { let sticks = StickSequence::new(UnitPowerLaw::new(10.0).unwrap(), None); diff --git a/src/misc/func.rs b/src/misc/func.rs index d684b70..8e54f1d 100644 --- a/src/misc/func.rs +++ b/src/misc/func.rs @@ -344,6 +344,55 @@ pub fn ln_fact(n: usize) -> f64 { } } +/// Generate a vector of sorted uniform random variables. +/// +/// # Arguments +/// +/// * `n` - The number of random variables to generate. +/// +/// * `rng` - A mutable reference to the random number generator. +/// +/// # Returns +/// +/// A vector of sorted uniform random variables. 
+/// +/// # Example +/// +/// ``` +/// use rand::thread_rng; +/// use rv::misc::sorted_uniforms; +/// +/// let mut rng = thread_rng(); +/// let n = 10000; +/// let xs = sorted_uniforms(n, &mut rng); +/// assert_eq!(xs.len(), n); +/// +/// // Result is sorted and in the unit interval +/// assert!(xs.first().map_or(false, |&first| first > 0.0)); +/// assert!(xs.last().map_or(false, |&last| last < 1.0)); +/// assert!(xs.windows(2).all(|w| w[0] <= w[1])); +/// +/// // Mean is approximately 1/2 +/// let mean = xs.iter().sum::<f64>() / n as f64; +/// assert!(mean > 0.49 && mean < 0.51); +/// +/// // Variance is approximately 1/12 +/// let var = xs.iter().map(|x| (x - 0.5).powi(2)).sum::<f64>() / n as f64; +/// assert!(var > 0.08 && var < 0.09); +/// ``` +pub fn sorted_uniforms<R: Rng>(n: usize, rng: &mut R) -> Vec<f64> { + let mut xs: Vec<_> = (0..n) + .map(|_| -rng.gen::<f64>().ln()) + .scan(0.0, |state, x| { + *state += x; + Some(*state) + }) + .collect(); + let max = *xs.last().unwrap() - rng.gen::<f64>().ln(); + (0..n).for_each(|i| xs[i] /= max); + xs +} + const LN_FACT: [f64; 255] = [ 0.000_000_000_000_000, 0.000_000_000_000_000, @@ -605,6 +654,9 @@ const LN_FACT: [f64; 255] = [ #[cfg(test)] mod tests { use super::*; + use crate::prelude::ChiSquared; + use crate::traits::Cdf; + use rand::thread_rng; const TOL: f64 = 1E-12; @@ -737,4 +789,50 @@ mod tests { assert_eq!(one_count, 0); assert!(two_count > 30); } + + #[test] + fn test_sorted_uniforms() { + let mut rng = thread_rng(); + let n = 1000; + let xs = sorted_uniforms(n, &mut rng); + assert_eq!(xs.len(), n); + + // Result is sorted and in the unit interval + assert!(&0.0 < xs.first().unwrap()); + assert!(xs.last().unwrap() < &1.0); + assert!(xs.windows(2).all(|w| w[0] <= w[1])); + + // t will aggregate our chi-squared test statistic + let mut t = 0.0; + + { + // We'll build a histogram and count the bin populations, aggregating + // the chi-squared statistic as we go + let mut next_bin = 0.01; + let mut 
bin_pop = 0; + + for x in xs.iter() { + bin_pop += 1; + if *x > next_bin { + let obs = bin_pop as f64; + let exp = n as f64 / 100.0; + t += (obs - exp).powi(2) / exp; + bin_pop = 0; + next_bin += 0.01; + } + } + + // The last bin + let obs = bin_pop as f64; + let exp = n as f64 / 100.0; + t += (obs - exp).powi(2) / exp; + } + + let alpha = 0.001; + + // dof = number of bins minus one + let chi2 = ChiSquared::new(99.0).unwrap(); + let p = chi2.sf(&t); + assert!(p > alpha); + } }