Skip to content

Commit

Permalink
move sorted_uniforms to func.rs
Browse files Browse the repository at this point in the history
  • Loading branch information
cscherrer committed Apr 15, 2024
1 parent 339db73 commit 9fa09ea
Show file tree
Hide file tree
Showing 2 changed files with 99 additions and 94 deletions.
95 changes: 1 addition & 94 deletions src/experimental/stick_breaking_process/sbd.rs
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
use super::StickSequence;
use crate::dist::Mixture;
use crate::misc::sorted_uniforms;
use crate::misc::ConvergentSequence;
use crate::traits::*;
use rand::seq::SliceRandom;
Expand Down Expand Up @@ -198,54 +199,6 @@ impl Mode<usize> for StickBreakingDiscrete {
}
}

/// Generate a vector of sorted uniform random variables.
///
/// The values are produced already in order, without a sorting pass: the
/// running sums of `n + 1` draws of `-ln(u)` are formed and the first `n` of
/// them are divided by the last one.
///
/// # Arguments
///
/// * `n` - The number of random variables to generate. `n == 0` yields an
///   empty vector and does not advance `rng`.
///
/// * `rng` - A mutable reference to the random number generator.
///
/// # Returns
///
/// A vector of `n` sorted uniform random variables.
///
/// # Example
///
/// ```
/// use rand::thread_rng;
/// use rv::experimental::stick_breaking_process::sbd::sorted_uniforms;
///
/// let mut rng = thread_rng();
/// let n = 10000;
/// let xs = sorted_uniforms(n, &mut rng);
/// assert_eq!(xs.len(), n);
///
/// // Result is sorted and in the unit interval
/// assert!(xs.first().map_or(false, |&first| first > 0.0));
/// assert!(xs.last().map_or(false, |&last| last < 1.0));
/// assert!(xs.windows(2).all(|w| w[0] <= w[1]));
///
/// // Mean is approximately 1/2
/// let mean = xs.iter().sum::<f64>() / n as f64;
/// assert!(mean > 0.49 && mean < 0.51);
///
/// // Variance is approximately 1/12
/// let var = xs.iter().map(|x| (x - 0.5).powi(2)).sum::<f64>() / n as f64;
/// assert!(var > 0.08 && var < 0.09);
/// ```
pub fn sorted_uniforms<R: Rng>(n: usize, rng: &mut R) -> Vec<f64> {
    // Cumulative sums of `-ln(u)` draws; each term is non-negative, so the
    // sequence is non-decreasing by construction.
    let mut xs: Vec<f64> = (0..n)
        .map(|_| -rng.gen::<f64>().ln())
        .scan(0.0, |state, x| {
            *state += x;
            Some(*state)
        })
        .collect();

    // Nothing to normalise when no values were requested. Returning here
    // avoids unwrapping the last element of an empty vector.
    let last = match xs.last() {
        Some(&last) => last,
        None => return xs,
    };

    // One extra draw extends the running sum past the final element, so every
    // returned value ends up below 1 after the division.
    let max = last - rng.gen::<f64>().ln();
    xs.iter_mut().for_each(|x| *x /= max);
    xs
}
/// Provides density and log-density functions for StickBreakingDiscrete.
impl HasDensity<usize> for StickBreakingDiscrete {
/// Computes the density of a given stick index.
Expand Down Expand Up @@ -353,52 +306,6 @@ mod tests {
use crate::prelude::*;
use rand::thread_rng;

/// Checks that `sorted_uniforms` returns `n` ordered values strictly inside
/// the unit interval, and that their histogram over equal-width bins passes a
/// chi-squared goodness-of-fit test against the uniform distribution.
#[test]
fn test_sorted_uniforms() {
    const N_BINS: usize = 100;

    let mut rng = thread_rng();
    let n = 1000;
    let xs = sorted_uniforms(n, &mut rng);
    assert_eq!(xs.len(), n);

    // Result is sorted and in the unit interval
    assert!(xs[0] > 0.0);
    assert!(xs[n - 1] < 1.0);
    assert!(xs.windows(2).all(|w| w[0] <= w[1]));

    // Assign every sample to its own bin directly. This keeps the sample that
    // crosses a bin edge out of the previous bin's count, and an empty bin
    // cannot shift the bins that follow it.
    let mut counts = [0_usize; N_BINS];
    for &x in &xs {
        // `min` guards against rounding in `x * N_BINS` producing an index
        // one past the last bin for values just below 1.0.
        let bin = ((x * N_BINS as f64) as usize).min(N_BINS - 1);
        counts[bin] += 1;
    }

    // Chi-squared test statistic: sum over bins of (obs - exp)^2 / exp
    let exp = n as f64 / N_BINS as f64;
    let t: f64 = counts
        .iter()
        .map(|&obs| (obs as f64 - exp).powi(2) / exp)
        .sum();

    let alpha = 0.001;

    // dof = number of bins minus one
    let chi2 = ChiSquared::new((N_BINS - 1) as f64).unwrap();
    let p = chi2.sf(&t);
    assert!(p > alpha);
}

#[test]
fn test_multi_invccdf_sorted() {
let sticks = StickSequence::new(UnitPowerLaw::new(10.0).unwrap(), None);
Expand Down
98 changes: 98 additions & 0 deletions src/misc/func.rs
Original file line number Diff line number Diff line change
Expand Up @@ -344,6 +344,55 @@ pub fn ln_fact(n: usize) -> f64 {
}
}

/// Generate a vector of sorted uniform random variables.
///
/// The values are produced already in order, without a sorting pass: the
/// running sums of `n + 1` draws of `-ln(u)` are formed and the first `n` of
/// them are divided by the last one.
///
/// # Arguments
///
/// * `n` - The number of random variables to generate. `n == 0` yields an
///   empty vector and does not advance `rng`.
///
/// * `rng` - A mutable reference to the random number generator.
///
/// # Returns
///
/// A vector of `n` sorted uniform random variables.
///
/// # Example
///
/// ```
/// use rand::thread_rng;
/// use rv::misc::sorted_uniforms;
///
/// let mut rng = thread_rng();
/// let n = 10000;
/// let xs = sorted_uniforms(n, &mut rng);
/// assert_eq!(xs.len(), n);
///
/// // Result is sorted and in the unit interval
/// assert!(xs.first().map_or(false, |&first| first > 0.0));
/// assert!(xs.last().map_or(false, |&last| last < 1.0));
/// assert!(xs.windows(2).all(|w| w[0] <= w[1]));
///
/// // Mean is approximately 1/2
/// let mean = xs.iter().sum::<f64>() / n as f64;
/// assert!(mean > 0.49 && mean < 0.51);
///
/// // Variance is approximately 1/12
/// let var = xs.iter().map(|x| (x - 0.5).powi(2)).sum::<f64>() / n as f64;
/// assert!(var > 0.08 && var < 0.09);
/// ```
pub fn sorted_uniforms<R: Rng>(n: usize, rng: &mut R) -> Vec<f64> {
    // Cumulative sums of `-ln(u)` draws; each term is non-negative, so the
    // sequence is non-decreasing by construction.
    let mut xs: Vec<f64> = (0..n)
        .map(|_| -rng.gen::<f64>().ln())
        .scan(0.0, |state, x| {
            *state += x;
            Some(*state)
        })
        .collect();

    // Nothing to normalise when no values were requested. Returning here
    // avoids unwrapping the last element of an empty vector.
    let last = match xs.last() {
        Some(&last) => last,
        None => return xs,
    };

    // One extra draw extends the running sum past the final element, so every
    // returned value ends up below 1 after the division.
    let max = last - rng.gen::<f64>().ln();
    xs.iter_mut().for_each(|x| *x /= max);
    xs
}

const LN_FACT: [f64; 255] = [
0.000_000_000_000_000,
0.000_000_000_000_000,
Expand Down Expand Up @@ -605,6 +654,9 @@ const LN_FACT: [f64; 255] = [
#[cfg(test)]
mod tests {
use super::*;
use crate::prelude::ChiSquared;
use crate::traits::Cdf;
use rand::thread_rng;

const TOL: f64 = 1E-12;

Expand Down Expand Up @@ -737,4 +789,50 @@ mod tests {
assert_eq!(one_count, 0);
assert!(two_count > 30);
}

/// Checks that `sorted_uniforms` returns `n` ordered values strictly inside
/// the unit interval, and that their histogram over equal-width bins passes a
/// chi-squared goodness-of-fit test against the uniform distribution.
#[test]
fn test_sorted_uniforms() {
    const N_BINS: usize = 100;

    let mut rng = thread_rng();
    let n = 1000;
    let xs = sorted_uniforms(n, &mut rng);
    assert_eq!(xs.len(), n);

    // Result is sorted and in the unit interval
    assert!(xs[0] > 0.0);
    assert!(xs[n - 1] < 1.0);
    assert!(xs.windows(2).all(|w| w[0] <= w[1]));

    // Assign every sample to its own bin directly. This keeps the sample that
    // crosses a bin edge out of the previous bin's count, and an empty bin
    // cannot shift the bins that follow it.
    let mut counts = [0_usize; N_BINS];
    for &x in &xs {
        // `min` guards against rounding in `x * N_BINS` producing an index
        // one past the last bin for values just below 1.0.
        let bin = ((x * N_BINS as f64) as usize).min(N_BINS - 1);
        counts[bin] += 1;
    }

    // Chi-squared test statistic: sum over bins of (obs - exp)^2 / exp
    let exp = n as f64 / N_BINS as f64;
    let t: f64 = counts
        .iter()
        .map(|&obs| (obs as f64 - exp).powi(2) / exp)
        .sum();

    let alpha = 0.001;

    // dof = number of bins minus one
    let chi2 = ChiSquared::new((N_BINS - 1) as f64).unwrap();
    let p = chi2.sf(&t);
    assert!(p > alpha);
}
}

0 comments on commit 9fa09ea

Please sign in to comment.