From e2092e925148b81740f91963e7aac8846987bda2 Mon Sep 17 00:00:00 2001 From: Benjamin Lieser Date: Tue, 1 Oct 2024 15:27:39 +0200 Subject: [PATCH 1/3] Poisson u64 sampling (#1498) This addresses https://github.com/rust-random/rand/issues/1497 by adding `Distribution` It also solves https://github.com/rust-random/rand/issues/1312 by not allowing `lambda` bigger than `1.844e19` (this also makes them always fit into `u64`) --- CHANGELOG.md | 2 ++ rand_distr/src/poisson.rs | 44 ++++++++++++++++++++++++++------------- 2 files changed, 31 insertions(+), 15 deletions(-) diff --git a/CHANGELOG.md b/CHANGELOG.md index d071e09391..15347e017d 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -24,6 +24,8 @@ You may also find the [Upgrade Guide](https://rust-random.github.io/book/update. - Add `UniformUsize` and use to make `Uniform` for `usize` portable (#1487) - Remove support for generating `isize` and `usize` values with `Standard`, `Uniform` and `Fill` and usage as a `WeightedAliasIndex` weight (#1487) - Require `Clone` and `AsRef` bound for `SeedableRng::Seed`. (#1491) +- Implement `Distribution` for `Poisson` (#1498) +- Limit the maximal acceptable lambda for `Poisson` to solve (#1312) (#1498) ## [0.9.0-alpha.1] - 2024-03-18 - Add the `Slice::num_choices` method to the Slice distribution (#1402) diff --git a/rand_distr/src/poisson.rs b/rand_distr/src/poisson.rs index 26e7712b2c..759f39cde7 100644 --- a/rand_distr/src/poisson.rs +++ b/rand_distr/src/poisson.rs @@ -23,10 +23,6 @@ use rand::Rng; /// This distribution has density function: /// `f(k) = λ^k * exp(-λ) / k!` for `k >= 0`. /// -/// # Known issues -/// -/// See documentation of [`Poisson::new`]. -/// /// # Plot /// /// The following plot shows the Poisson distribution with various values of `λ`. @@ -40,7 +36,7 @@ use rand::Rng; /// use rand_distr::{Poisson, Distribution}; /// /// let poi = Poisson::new(2.0).unwrap(); -/// let v = poi.sample(&mut rand::thread_rng()); +/// let v: f64 = poi.sample(&mut rand::thread_rng()); /// println!("{} is from a Poisson(2) distribution", v); /// ``` #[derive(Clone, Copy, Debug, PartialEq)] @@ -52,13 +48,13 @@ where /// Error type returned from [`Poisson::new`]. #[derive(Clone, Copy, Debug, PartialEq, Eq)] -// Marked non_exhaustive to allow a new error code in the solution to #1312. -#[non_exhaustive] pub enum Error { /// `lambda <= 0` ShapeTooSmall, /// `lambda = ∞` or `lambda = nan` NonFinite, + /// `lambda` is too large, see [Poisson::MAX_LAMBDA] + ShapeTooLarge, } impl fmt::Display for Error { @@ -66,6 +62,9 @@ impl fmt::Display for Error { f.write_str(match self { Error::ShapeTooSmall => "lambda is not positive in Poisson distribution", Error::NonFinite => "lambda is infinite or nan in Poisson distribution", + Error::ShapeTooLarge => { + "lambda is too large in Poisson distribution, see Poisson::MAX_LAMBDA" + } }) } } @@ -125,14 +124,7 @@ where /// Construct a new `Poisson` with the given shape parameter /// `lambda`. /// - /// # Known issues - /// - /// Although this method should return an [`Error`] on invalid parameters, - /// some (extreme) values of `lambda` are known to return a [`Poisson`] - /// object which hangs when [sampled](Distribution::sample). - /// Large (less extreme) values of `lambda` may result in successful - /// sampling but with reduced precision. - /// See [#1312](https://github.com/rust-random/rand/issues/1312). + /// The maximum allowed lambda is [MAX_LAMBDA](Self::MAX_LAMBDA). pub fn new(lambda: F) -> Result, Error> { if !lambda.is_finite() { return Err(Error::NonFinite); @@ -145,11 +137,25 @@ where let method = if lambda < F::from(12.0).unwrap() { Method::Knuth(KnuthMethod::new(lambda)) } else { + if lambda > F::from(Self::MAX_LAMBDA).unwrap() { + return Err(Error::ShapeTooLarge); + } Method::Rejection(RejectionMethod::new(lambda)) }; Ok(Poisson(method)) } + + /// The maximum supported value of `lambda` + /// + /// This value was selected such that + /// `MAX_LAMBDA + 1e6 * sqrt(MAX_LAMBDA) < 2^64 - 1`, + /// thus ensuring that the probability of sampling a value larger than + /// `u64::MAX` is less than 1e-1000. + /// + /// Applying this limit also solves + /// [#1312](https://github.com/rust-random/rand/issues/1312). + pub const MAX_LAMBDA: f64 = 1.844e19; } impl Distribution for KnuthMethod @@ -232,6 +238,14 @@ where } } +impl Distribution for Poisson { + #[inline] + fn sample(&self, rng: &mut R) -> u64 { + // `as` from float to int saturates + as Distribution>::sample(self, rng) as u64 + } +} + #[cfg(test)] mod test { use super::*; From 66b11eb17bb256bc7461278d38b2671685db532a Mon Sep 17 00:00:00 2001 From: Diggory Hardy Date: Tue, 1 Oct 2024 14:59:11 +0100 Subject: [PATCH 2/3] =?UTF-8?q?Rename=20gen=5Fiter=20=E2=86=92=20random=5F?= =?UTF-8?q?iter,=20misc..=20(#1500)?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit This extracts the non-inherent-methods stuff from #1492. --- CHANGELOG.md | 1 + README.md | 31 +++++---- benches/benches/seq_choose.rs | 2 +- rand_core/src/lib.rs | 6 +- rand_distr/src/weighted_alias.rs | 7 +- src/lib.rs | 23 ++++--- src/rng.rs | 108 ++++++++++++++++--------------- src/seq/index.rs | 10 +-- 8 files changed, 99 insertions(+), 89 deletions(-) diff --git a/CHANGELOG.md b/CHANGELOG.md index 15347e017d..3300b9ad9f 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -26,6 +26,7 @@ You may also find the [Upgrade Guide](https://rust-random.github.io/book/update. - Require `Clone` and `AsRef` bound for `SeedableRng::Seed`. (#1491) - Implement `Distribution` for `Poisson` (#1498) - Limit the maximal acceptable lambda for `Poisson` to solve (#1312) (#1498) +- Rename `Rng::gen_iter` to `random_iter` (#1500) ## [0.9.0-alpha.1] - 2024-03-18 - Add the `Slice::num_choices` method to the Slice distribution (#1402) diff --git a/README.md b/README.md index 18f22a89eb..25341ac2d0 100644 --- a/README.md +++ b/README.md @@ -6,26 +6,31 @@ [![API](https://img.shields.io/badge/api-master-yellow.svg)](https://rust-random.github.io/rand/rand) [![API](https://docs.rs/rand/badge.svg)](https://docs.rs/rand) -A Rust library for random number generation, featuring: +Rand is a Rust library supporting random generators: -- Easy random value generation and usage via the [`Rng`](https://docs.rs/rand/*/rand/trait.Rng.html), - [`SliceRandom`](https://docs.rs/rand/*/rand/seq/trait.SliceRandom.html) and - [`IteratorRandom`](https://docs.rs/rand/*/rand/seq/trait.IteratorRandom.html) traits -- Secure seeding via the [`getrandom` crate](https://crates.io/crates/getrandom) - and fast, convenient generation via [`thread_rng`](https://docs.rs/rand/*/rand/fn.thread_rng.html) -- A modular design built over [`rand_core`](https://crates.io/crates/rand_core) - ([see the book](https://rust-random.github.io/book/crates.html)) +- A standard RNG trait: [`rand_core::RngCore`](https://docs.rs/rand_core/latest/rand_core/trait.RngCore.html) - Fast implementations of the best-in-class [cryptographic](https://rust-random.github.io/book/guide-rngs.html#cryptographically-secure-pseudo-random-number-generators-csprngs) and - [non-cryptographic](https://rust-random.github.io/book/guide-rngs.html#basic-pseudo-random-number-generators-prngs) generators + [non-cryptographic](https://rust-random.github.io/book/guide-rngs.html#basic-pseudo-random-number-generators-prngs) generators: [`rand::rngs`](https://docs.rs/rand/latest/rand/rngs/index.html), and more RNGs: [`rand_chacha`](https://docs.rs/rand_chacha), [`rand_xoshiro`](https://docs.rs/rand_xoshiro/), [`rand_pcg`](https://docs.rs/rand_pcg/), [rngs repo](https://github.com/rust-random/rngs/) +- [`rand::thread_rng`](https://docs.rs/rand/latest/rand/fn.thread_rng.html) is an asymtotically-fast, reasonably secure generator available on all `std` targets +- Secure seeding via the [`getrandom` crate](https://crates.io/crates/getrandom) + +Supporting random value generation and random processes: + +- [`Standard`](https://docs.rs/rand/latest/rand/distributions/struct.Standard.html) random value generation +- Ranged [`Uniform`](https://docs.rs/rand/latest/rand/distributions/struct.Uniform.html) number generation for many types - A flexible [`distributions`](https://docs.rs/rand/*/rand/distr/index.html) module - Samplers for a large number of random number distributions via our own [`rand_distr`](https://docs.rs/rand_distr) and via the [`statrs`](https://docs.rs/statrs/0.13.0/statrs/) +- Random processes (mostly choose and shuffle) via [`rand::seq`](https://docs.rs/rand/latest/rand/seq/index.html) traits + +All with: + - [Portably reproducible output](https://rust-random.github.io/book/portability.html) - `#[no_std]` compatibility (partial) - *Many* performance optimisations -It's also worth pointing out what `rand` *is not*: +It's also worth pointing out what Rand *is not*: - Small. Most low-level crates are small, but the higher-level `rand` and `rand_distr` each contain a lot of functionality. @@ -73,8 +78,7 @@ Rand is built with these features enabled by default: - `alloc` (implied by `std`) enables functionality requiring an allocator - `getrandom` (implied by `std`) is an optional dependency providing the code behind `rngs::OsRng` -- `std_rng` enables inclusion of `StdRng`, `thread_rng` and `random` - (the latter two *also* require that `std` be enabled) +- `std_rng` enables inclusion of `StdRng`, `thread_rng` Optionally, the following dependencies can be enabled: @@ -94,8 +98,7 @@ experimental `simd_support` feature. Rand supports limited functionality in `no_std` mode (enabled via `default-features = false`). In this case, `OsRng` and `from_os_rng` are unavailable (unless `getrandom` is enabled), large parts of `seq` are -unavailable (unless `alloc` is enabled), and `thread_rng` and `random` are -unavailable. +unavailable (unless `alloc` is enabled), and `thread_rng` is unavailable. ## Portability and platform support diff --git a/benches/benches/seq_choose.rs b/benches/benches/seq_choose.rs index f418f9cc4d..58c4f894ea 100644 --- a/benches/benches/seq_choose.rs +++ b/benches/benches/seq_choose.rs @@ -19,7 +19,7 @@ criterion_group!( criterion_main!(benches); pub fn bench(c: &mut Criterion) { - c.bench_function("seq_slice_choose_1_of_1000", |b| { + c.bench_function("seq_slice_choose_1_of_100", |b| { let mut rng = Pcg32::from_rng(thread_rng()); let mut buf = [0i32; 100]; rng.fill(&mut buf); diff --git a/rand_core/src/lib.rs b/rand_core/src/lib.rs index 3c16a9767c..39e95d95db 100644 --- a/rand_core/src/lib.rs +++ b/rand_core/src/lib.rs @@ -54,11 +54,11 @@ pub use getrandom; #[cfg(feature = "getrandom")] pub use os::OsRng; -/// The core of a random number generator. +/// Implementation-level interface for RNGs /// /// This trait encapsulates the low-level functionality common to all /// generators, and is the "back end", to be implemented by generators. -/// End users should normally use the `Rng` trait from the [`rand`] crate, +/// End users should normally use the [`rand::Rng`] trait /// which is automatically implemented for every type implementing `RngCore`. /// /// Three different methods for generating random data are provided since the @@ -129,7 +129,7 @@ pub use os::OsRng; /// rand_core::impl_try_rng_from_rng_core!(CountingRng); /// ``` /// -/// [`rand`]: https://docs.rs/rand +/// [`rand::Rng`]: https://docs.rs/rand/latest/rand/trait.Rng.html /// [`fill_bytes`]: RngCore::fill_bytes /// [`next_u32`]: RngCore::next_u32 /// [`next_u64`]: RngCore::next_u64 diff --git a/rand_distr/src/weighted_alias.rs b/rand_distr/src/weighted_alias.rs index 593219cafd..537060f388 100644 --- a/rand_distr/src/weighted_alias.rs +++ b/rand_distr/src/weighted_alias.rs @@ -275,9 +275,10 @@ where } } -/// Trait that must be implemented for weights, that are used with -/// [`WeightedAliasIndex`]. Currently no guarantees on the correctness of -/// [`WeightedAliasIndex`] are given for custom implementations of this trait. +/// Weight bound for [`WeightedAliasIndex`] +/// +/// Currently no guarantees on the correctness of [`WeightedAliasIndex`] are +/// given for custom implementations of this trait. pub trait AliasableWeight: Sized + Copy diff --git a/src/lib.rs b/src/lib.rs index 3abbc5a266..958c15d481 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -14,25 +14,24 @@ //! //! # Quick Start //! -//! To get you started quickly, the easiest and highest-level way to get -//! a random value is to use [`random()`]; alternatively you can use -//! [`thread_rng()`]. The [`Rng`] trait provides a useful API on all RNGs, while -//! the [`distr`] and [`seq`] modules provide further -//! functionality on top of RNGs. -//! //! ``` +//! // The prelude import enables methods we use below, specifically +//! // Rng::random, Rng::sample, SliceRandom::shuffle and IndexedRandom::choose. //! use rand::prelude::*; //! -//! if rand::random() { // generates a boolean -//! // Try printing a random unicode code point (probably a bad idea)! -//! println!("char: {}", rand::random::()); -//! } -//! +//! // Get an RNG: //! let mut rng = rand::thread_rng(); -//! let y: f64 = rng.random(); // generates a float between 0 and 1 //! +//! // Try printing a random unicode code point (probably a bad idea)! +//! println!("char: '{}'", rng.random::()); +//! // Try printing a random alphanumeric value instead! +//! println!("alpha: '{}'", rng.sample(rand::distr::Alphanumeric) as char); +//! +//! // Generate and shuffle a sequence: //! let mut nums: Vec = (1..100).collect(); //! nums.shuffle(&mut rng); +//! // And take a random pick (yes, we didn't need to shuffle first!): +//! let _ = nums.choose(&mut rng); //! ``` //! //! # The Book diff --git a/src/rng.rs b/src/rng.rs index 9190747a29..7c9e887a2d 100644 --- a/src/rng.rs +++ b/src/rng.rs @@ -15,10 +15,14 @@ use core::num::Wrapping; use core::{mem, slice}; use rand_core::RngCore; -/// An automatically-implemented extension trait on [`RngCore`] providing high-level -/// generic methods for sampling values and other convenience methods. +/// User-level interface for RNGs /// -/// This is the primary trait to use when generating random values. +/// [`RngCore`] is the `dyn`-safe implementation-level interface for Random +/// (Number) Generators. This trait, `Rng`, provides a user-level interface on +/// RNGs. It is implemented automatically for any `R: RngCore`. +/// +/// This trait must usually be brought into scope via `use rand::Rng;` or +/// `use rand::prelude::*;`. /// /// # Generic usage /// @@ -96,55 +100,13 @@ pub trait Rng: RngCore { Standard.sample(self) } - /// Generate a random value in the given range. - /// - /// This function is optimised for the case that only a single sample is - /// made from the given range. See also the [`Uniform`] distribution - /// type which may be faster if sampling from the same range repeatedly. - /// - /// All types support `low..high_exclusive` and `low..=high` range syntax. - /// Unsigned integer types also support `..high_exclusive` and `..=high` syntax. - /// - /// # Panics - /// - /// Panics if the range is empty, or if `high - low` overflows for floats. - /// - /// # Example - /// - /// ``` - /// use rand::{thread_rng, Rng}; - /// - /// let mut rng = thread_rng(); - /// - /// // Exclusive range - /// let n: u32 = rng.gen_range(..10); - /// println!("{}", n); - /// let m: f64 = rng.gen_range(-40.0..1.3e5); - /// println!("{}", m); - /// - /// // Inclusive range - /// let n: u32 = rng.gen_range(..=10); - /// println!("{}", n); - /// ``` - /// - /// [`Uniform`]: distr::uniform::Uniform - #[track_caller] - fn gen_range(&mut self, range: R) -> T - where - T: SampleUniform, - R: SampleRange, - { - assert!(!range.is_empty(), "cannot sample empty range"); - range.sample_single(self).unwrap() - } - - /// Generate values via an iterator + /// Return an iterator over [`random`](Self::random) variates /// /// This is a just a wrapper over [`Rng::sample_iter`] using /// [`distr::Standard`]. /// /// Note: this method consumes its argument. Use - /// `(&mut rng).gen_iter()` to avoid consuming the RNG. + /// `(&mut rng).random_iter()` to avoid consuming the RNG. /// /// # Example /// @@ -152,11 +114,11 @@ pub trait Rng: RngCore { /// use rand::{rngs::mock::StepRng, Rng}; /// /// let rng = StepRng::new(1, 1); - /// let v: Vec = rng.gen_iter().take(5).collect(); + /// let v: Vec = rng.random_iter().take(5).collect(); /// assert_eq!(&v, &[1, 2, 3, 4, 5]); /// ``` #[inline] - fn gen_iter(self) -> distr::DistIter + fn random_iter(self) -> distr::DistIter where Self: Sized, Standard: Distribution, @@ -247,6 +209,48 @@ pub trait Rng: RngCore { dest.fill(self) } + /// Generate a random value in the given range. + /// + /// This function is optimised for the case that only a single sample is + /// made from the given range. See also the [`Uniform`] distribution + /// type which may be faster if sampling from the same range repeatedly. + /// + /// All types support `low..high_exclusive` and `low..=high` range syntax. + /// Unsigned integer types also support `..high_exclusive` and `..=high` syntax. + /// + /// # Panics + /// + /// Panics if the range is empty, or if `high - low` overflows for floats. + /// + /// # Example + /// + /// ``` + /// use rand::{thread_rng, Rng}; + /// + /// let mut rng = thread_rng(); + /// + /// // Exclusive range + /// let n: u32 = rng.gen_range(..10); + /// println!("{}", n); + /// let m: f64 = rng.gen_range(-40.0..1.3e5); + /// println!("{}", m); + /// + /// // Inclusive range + /// let n: u32 = rng.gen_range(..=10); + /// println!("{}", n); + /// ``` + /// + /// [`Uniform`]: distr::uniform::Uniform + #[track_caller] + fn gen_range(&mut self, range: R) -> T + where + T: SampleUniform, + R: SampleRange, + { + assert!(!range.is_empty(), "cannot sample empty range"); + range.sample_single(self).unwrap() + } + /// Return a bool with a probability `p` of being true. /// /// See also the [`Bernoulli`] distribution, which may be faster if @@ -316,7 +320,7 @@ pub trait Rng: RngCore { since = "0.9.0", note = "Renamed to `random` to avoid conflict with the new `gen` keyword in Rust 2024." )] - fn gen(&mut self) -> T + fn r#gen(&mut self) -> T where Standard: Distribution, { @@ -474,8 +478,8 @@ mod test { // Check equivalence for generated floats let mut array = [0f32; 2]; rng.fill(&mut array); - let gen: [f32; 2] = rng.random(); - assert_eq!(array, gen); + let arr2: [f32; 2] = rng.random(); + assert_eq!(array, arr2); } #[test] diff --git a/src/seq/index.rs b/src/seq/index.rs index 5bb1a7597f..e66b503988 100644 --- a/src/seq/index.rs +++ b/src/seq/index.rs @@ -282,10 +282,12 @@ where } } -/// Randomly sample exactly `amount` distinct indices from `0..length`, and -/// return them in an arbitrary order (there is no guarantee of shuffling or -/// ordering). The weights are to be provided by the input function `weights`, -/// which will be called once for each index. +/// Randomly sample exactly `amount` distinct indices from `0..length` +/// +/// Results are in arbitrary order (there is no guarantee of shuffling or +/// ordering). +/// +/// Function `weight` is called once for each index to provide weights. /// /// This method is used internally by the slice sampling methods, but it can /// sometimes be useful to have the indices themselves so this is provided as From bc3341185ee9fd63e63a7c3266f28478aa2ae5fd Mon Sep 17 00:00:00 2001 From: Benjamin Lieser Date: Thu, 3 Oct 2024 21:18:56 +0200 Subject: [PATCH 3/3] Make sure BTPE is not entered when np < 10 (#1484) --- rand_distr/CHANGELOG.md | 2 + rand_distr/src/binomial.rs | 483 ++++++++++++++++++++----------------- 2 files changed, 270 insertions(+), 215 deletions(-) diff --git a/rand_distr/CHANGELOG.md b/rand_distr/CHANGELOG.md index 51bde39e86..93756eb705 100644 --- a/rand_distr/CHANGELOG.md +++ b/rand_distr/CHANGELOG.md @@ -6,6 +6,8 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0 ## Unreleased - The `serde1` feature has been renamed `serde` (#1477) +- Fix panic in Binomial (#1484) +- Move some of the computations in Binomial from `sample` to `new` (#1484) ### Added - Add plots for `rand_distr` distributions to documentation (#1434) diff --git a/rand_distr/src/binomial.rs b/rand_distr/src/binomial.rs index 885d8b21c3..3ee0f447b4 100644 --- a/rand_distr/src/binomial.rs +++ b/rand_distr/src/binomial.rs @@ -26,10 +26,6 @@ use rand::Rng; /// /// `f(k) = n!/(k! (n-k)!) p^k (1-p)^(n-k)` for `k >= 0`. /// -/// # Known issues -/// -/// See documentation of [`Binomial::new`]. -/// /// # Plot /// /// The following plot of the binomial distribution illustrates the @@ -50,10 +46,34 @@ use rand::Rng; #[derive(Clone, Copy, Debug, PartialEq)] #[cfg_attr(feature = "serde", derive(serde::Serialize, serde::Deserialize))] pub struct Binomial { - /// Number of trials. + method: Method, +} + +#[derive(Clone, Copy, Debug, PartialEq)] +#[cfg_attr(feature = "serde", derive(serde::Serialize, serde::Deserialize))] +enum Method { + Binv(Binv, bool), + Btpe(Btpe, bool), + Poisson(crate::poisson::KnuthMethod), + Constant(u64), +} + +#[derive(Clone, Copy, Debug, PartialEq)] +#[cfg_attr(feature = "serde", derive(serde::Serialize, serde::Deserialize))] +struct Binv { + r: f64, + s: f64, + a: f64, + n: u64, +} + +#[derive(Clone, Copy, Debug, PartialEq)] +#[cfg_attr(feature = "serde", derive(serde::Serialize, serde::Deserialize))] +struct Btpe { n: u64, - /// Probability of success. p: f64, + m: i64, + p1: f64, } /// Error type returned from [`Binomial::new`]. @@ -82,13 +102,6 @@ impl std::error::Error for Error {} impl Binomial { /// Construct a new `Binomial` with the given shape parameters `n` (number /// of trials) and `p` (probability of success). - /// - /// # Known issues - /// - /// Although this method should return an [`Error`] on invalid parameters, - /// some (extreme) parameter combinations are known to return a [`Binomial`] - /// object which panics when [sampled](Distribution::sample). - /// See [#1378](https://github.com/rust-random/rand/issues/1378). pub fn new(n: u64, p: f64) -> Result { if !(p >= 0.0) { return Err(Error::ProbabilityTooSmall); @@ -96,33 +109,22 @@ impl Binomial { if !(p <= 1.0) { return Err(Error::ProbabilityTooLarge); } - Ok(Binomial { n, p }) - } -} - -/// Convert a `f64` to an `i64`, panicking on overflow. -fn f64_to_i64(x: f64) -> i64 { - assert!(x < (i64::MAX as f64)); - x as i64 -} -impl Distribution for Binomial { - #[allow(clippy::many_single_char_names)] // Same names as in the reference. - fn sample(&self, rng: &mut R) -> u64 { - // Handle these values directly. - if self.p == 0.0 { - return 0; - } else if self.p == 1.0 { - return self.n; + if p == 0.0 { + return Ok(Binomial { + method: Method::Constant(0), + }); } - // The binomial distribution is symmetrical with respect to p -> 1-p, - // k -> n-k switch p so that it is less than 0.5 - this allows for lower - // expected values we will just invert the result at the end - let p = if self.p <= 0.5 { self.p } else { 1.0 - self.p }; + if p == 1.0 { + return Ok(Binomial { + method: Method::Constant(n), + }); + } - let result; - let q = 1. - p; + // The binomial distribution is symmetrical with respect to p -> 1-p + let flipped = p > 0.5; + let p = if flipped { 1.0 - p } else { p }; // For small n * min(p, 1 - p), the BINV algorithm based on the inverse // transformation of the binomial distribution is efficient. Otherwise, @@ -136,204 +138,253 @@ impl Distribution for Binomial { // Ranlib uses 30, and GSL uses 14. const BINV_THRESHOLD: f64 = 10.; - // Same value as in GSL. - // It is possible for BINV to get stuck, so we break if x > BINV_MAX_X and try again. - // It would be safer to set BINV_MAX_X to self.n, but it is extremely unlikely to be relevant. - // When n*p < 10, so is n*p*q which is the variance, so a result > 110 would be 100 / sqrt(10) = 31 standard deviations away. - const BINV_MAX_X: u64 = 110; - - if (self.n as f64) * p < BINV_THRESHOLD && self.n <= (i32::MAX as u64) { - // Use the BINV algorithm. - let s = p / q; - let a = ((self.n + 1) as f64) * s; - - result = 'outer: loop { - let mut r = q.powi(self.n as i32); - let mut u: f64 = rng.random(); - let mut x = 0; - - while u > r { - u -= r; - x += 1; - if x > BINV_MAX_X { - continue 'outer; - } - r *= a / (x as f64) - s; - } - break x; + let np = n as f64 * p; + let method = if np < BINV_THRESHOLD { + let q = 1.0 - p; + if q == 1.0 { + // p is so small that this is extremely close to a Poisson distribution. + // The flipped case cannot occur here. + Method::Poisson(crate::poisson::KnuthMethod::new(np)) + } else { + let s = p / q; + Method::Binv( + Binv { + r: q.powf(n as f64), + s, + a: (n as f64 + 1.0) * s, + n, + }, + flipped, + ) } } else { - // Use the BTPE algorithm. - - // Threshold for using the squeeze algorithm. This can be freely - // chosen based on performance. Ranlib and GSL use 20. - const SQUEEZE_THRESHOLD: i64 = 20; - - // Step 0: Calculate constants as functions of `n` and `p`. - let n = self.n as f64; - let np = n * p; + let q = 1.0 - p; let npq = np * q; + let p1 = (2.195 * npq.sqrt() - 4.6 * q).floor() + 0.5; let f_m = np + p; let m = f64_to_i64(f_m); - // radius of triangle region, since height=1 also area of region - let p1 = (2.195 * npq.sqrt() - 4.6 * q).floor() + 0.5; - // tip of triangle - let x_m = (m as f64) + 0.5; - // left edge of triangle - let x_l = x_m - p1; - // right edge of triangle - let x_r = x_m + p1; - let c = 0.134 + 20.5 / (15.3 + (m as f64)); - // p1 + area of parallelogram region - let p2 = p1 * (1. + 2. * c); - - fn lambda(a: f64) -> f64 { - a * (1. + 0.5 * a) + Method::Btpe(Btpe { n, p, m, p1 }, flipped) + }; + Ok(Binomial { method }) + } +} + +/// Convert a `f64` to an `i64`, panicking on overflow. +fn f64_to_i64(x: f64) -> i64 { + assert!(x < (i64::MAX as f64)); + x as i64 +} + +fn binv(binv: Binv, flipped: bool, rng: &mut R) -> u64 { + // Same value as in GSL. + // It is possible for BINV to get stuck, so we break if x > BINV_MAX_X and try again. + // It would be safer to set BINV_MAX_X to self.n, but it is extremely unlikely to be relevant. + // When n*p < 10, so is n*p*q which is the variance, so a result > 110 would be 100 / sqrt(10) = 31 standard deviations away. + const BINV_MAX_X: u64 = 110; + + let sample = 'outer: loop { + let mut r = binv.r; + let mut u: f64 = rng.random(); + let mut x = 0; + + while u > r { + u -= r; + x += 1; + if x > BINV_MAX_X { + continue 'outer; } + r *= binv.a / (x as f64) - binv.s; + } + break x; + }; - let lambda_l = lambda((f_m - x_l) / (f_m - x_l * p)); - let lambda_r = lambda((x_r - f_m) / (x_r * q)); - // p1 + area of left tail - let p3 = p2 + c / lambda_l; - // p1 + area of right tail - let p4 = p3 + c / lambda_r; - - // return value - let mut y: i64; - - let gen_u = Uniform::new(0., p4).unwrap(); - let gen_v = Uniform::new(0., 1.).unwrap(); - - loop { - // Step 1: Generate `u` for selecting the region. If region 1 is - // selected, generate a triangularly distributed variate. - let u = gen_u.sample(rng); - let mut v = gen_v.sample(rng); - if !(u > p1) { - y = f64_to_i64(x_m - p1 * v + u); - break; - } + if flipped { + binv.n - sample + } else { + sample + } +} - if !(u > p2) { - // Step 2: Region 2, parallelograms. Check if region 2 is - // used. If so, generate `y`. - let x = x_l + (u - p1) / c; - v = v * c + 1.0 - (x - x_m).abs() / p1; - if v > 1. { - continue; - } else { - y = f64_to_i64(x); - } - } else if !(u > p3) { - // Step 3: Region 3, left exponential tail. - y = f64_to_i64(x_l + v.ln() / lambda_l); - if y < 0 { - continue; - } else { - v *= (u - p2) * lambda_l; - } - } else { - // Step 4: Region 4, right exponential tail. - y = f64_to_i64(x_r - v.ln() / lambda_r); - if y > 0 && (y as u64) > self.n { - continue; - } else { - v *= (u - p3) * lambda_r; - } - } +#[allow(clippy::many_single_char_names)] // Same names as in the reference. +fn btpe(btpe: Btpe, flipped: bool, rng: &mut R) -> u64 { + // Threshold for using the squeeze algorithm. This can be freely + // chosen based on performance. Ranlib and GSL use 20. + const SQUEEZE_THRESHOLD: i64 = 20; + + // Step 0: Calculate constants as functions of `n` and `p`. + let n = btpe.n as f64; + let np = n * btpe.p; + let q = 1. - btpe.p; + let npq = np * q; + let f_m = np + btpe.p; + let m = btpe.m; + // radius of triangle region, since height=1 also area of region + let p1 = btpe.p1; + // tip of triangle + let x_m = (m as f64) + 0.5; + // left edge of triangle + let x_l = x_m - p1; + // right edge of triangle + let x_r = x_m + p1; + let c = 0.134 + 20.5 / (15.3 + (m as f64)); + // p1 + area of parallelogram region + let p2 = p1 * (1. + 2. * c); + + fn lambda(a: f64) -> f64 { + a * (1. + 0.5 * a) + } - // Step 5: Acceptance/rejection comparison. - - // Step 5.0: Test for appropriate method of evaluating f(y). - let k = (y - m).abs(); - if !(k > SQUEEZE_THRESHOLD && (k as f64) < 0.5 * npq - 1.) { - // Step 5.1: Evaluate f(y) via the recursive relationship. Start the - // search from the mode. - let s = p / q; - let a = s * (n + 1.); - let mut f = 1.0; - match m.cmp(&y) { - Ordering::Less => { - let mut i = m; - loop { - i += 1; - f *= a / (i as f64) - s; - if i == y { - break; - } - } - } - Ordering::Greater => { - let mut i = y; - loop { - i += 1; - f /= a / (i as f64) - s; - if i == m { - break; - } - } + let lambda_l = lambda((f_m - x_l) / (f_m - x_l * btpe.p)); + let lambda_r = lambda((x_r - f_m) / (x_r * q)); + + let p3 = p2 + c / lambda_l; + + let p4 = p3 + c / lambda_r; + + // return value + let mut y: i64; + + let gen_u = Uniform::new(0., p4).unwrap(); + let gen_v = Uniform::new(0., 1.).unwrap(); + + loop { + // Step 1: Generate `u` for selecting the region. If region 1 is + // selected, generate a triangularly distributed variate. + let u = gen_u.sample(rng); + let mut v = gen_v.sample(rng); + if !(u > p1) { + y = f64_to_i64(x_m - p1 * v + u); + break; + } + + if !(u > p2) { + // Step 2: Region 2, parallelograms. Check if region 2 is + // used. If so, generate `y`. + let x = x_l + (u - p1) / c; + v = v * c + 1.0 - (x - x_m).abs() / p1; + if v > 1. { + continue; + } else { + y = f64_to_i64(x); + } + } else if !(u > p3) { + // Step 3: Region 3, left exponential tail. + y = f64_to_i64(x_l + v.ln() / lambda_l); + if y < 0 { + continue; + } else { + v *= (u - p2) * lambda_l; + } + } else { + // Step 4: Region 4, right exponential tail. + y = f64_to_i64(x_r - v.ln() / lambda_r); + if y > 0 && (y as u64) > btpe.n { + continue; + } else { + v *= (u - p3) * lambda_r; + } + } + + // Step 5: Acceptance/rejection comparison. + + // Step 5.0: Test for appropriate method of evaluating f(y). + let k = (y - m).abs(); + if !(k > SQUEEZE_THRESHOLD && (k as f64) < 0.5 * npq - 1.) { + // Step 5.1: Evaluate f(y) via the recursive relationship. Start the + // search from the mode. + let s = btpe.p / q; + let a = s * (n + 1.); + let mut f = 1.0; + match m.cmp(&y) { + Ordering::Less => { + let mut i = m; + loop { + i += 1; + f *= a / (i as f64) - s; + if i == y { + break; } - Ordering::Equal => {} } - if v > f { - continue; - } else { - break; - } - } - - // Step 5.2: Squeezing. Check the value of ln(v) against upper and - // lower bound of ln(f(y)). - let k = k as f64; - let rho = (k / npq) * ((k * (k / 3. + 0.625) + 1. / 6.) / npq + 0.5); - let t = -0.5 * k * k / npq; - let alpha = v.ln(); - if alpha < t - rho { - break; } - if alpha > t + rho { - continue; + Ordering::Greater => { + let mut i = y; + loop { + i += 1; + f /= a / (i as f64) - s; + if i == m { + break; + } + } } + Ordering::Equal => {} + } + if v > f { + continue; + } else { + break; + } + } - // Step 5.3: Final acceptance/rejection test. - let x1 = (y + 1) as f64; - let f1 = (m + 1) as f64; - let z = (f64_to_i64(n) + 1 - m) as f64; - let w = (f64_to_i64(n) - y + 1) as f64; + // Step 5.2: Squeezing. Check the value of ln(v) against upper and + // lower bound of ln(f(y)). + let k = k as f64; + let rho = (k / npq) * ((k * (k / 3. + 0.625) + 1. / 6.) / npq + 0.5); + let t = -0.5 * k * k / npq; + let alpha = v.ln(); + if alpha < t - rho { + break; + } + if alpha > t + rho { + continue; + } - fn stirling(a: f64) -> f64 { - let a2 = a * a; - (13860. - (462. - (132. - (99. - 140. / a2) / a2) / a2) / a2) / a / 166320. - } + // Step 5.3: Final acceptance/rejection test. + let x1 = (y + 1) as f64; + let f1 = (m + 1) as f64; + let z = (f64_to_i64(n) + 1 - m) as f64; + let w = (f64_to_i64(n) - y + 1) as f64; - if alpha - > x_m * (f1 / x1).ln() - + (n - (m as f64) + 0.5) * (z / w).ln() - + ((y - m) as f64) * (w * p / (x1 * q)).ln() - // We use the signs from the GSL implementation, which are - // different than the ones in the reference. According to - // the GSL authors, the new signs were verified to be - // correct by one of the original designers of the - // algorithm. - + stirling(f1) - + stirling(z) - - stirling(x1) - - stirling(w) - { - continue; - } + fn stirling(a: f64) -> f64 { + let a2 = a * a; + (13860. - (462. - (132. - (99. - 140. / a2) / a2) / a2) / a2) / a / 166320. + } - break; - } - assert!(y >= 0); - result = y as u64; + if alpha + > x_m * (f1 / x1).ln() + + (n - (m as f64) + 0.5) * (z / w).ln() + + ((y - m) as f64) * (w * btpe.p / (x1 * q)).ln() + // We use the signs from the GSL implementation, which are + // different than the ones in the reference. According to + // the GSL authors, the new signs were verified to be + // correct by one of the original designers of the + // algorithm. + + stirling(f1) + + stirling(z) + - stirling(x1) + - stirling(w) + { + continue; } - // Invert the result for p < 0.5. - if p != self.p { - self.n - result - } else { - result + break; + } + assert!(y >= 0); + let y = y as u64; + + if flipped { + btpe.n - y + } else { + y + } +} + +impl Distribution for Binomial { + fn sample(&self, rng: &mut R) -> u64 { + match self.method { + Method::Binv(binv_para, flipped) => binv(binv_para, flipped, rng), + Method::Btpe(btpe_para, flipped) => btpe(btpe_para, flipped, rng), + Method::Poisson(poisson) => poisson.sample(rng) as u64, + Method::Constant(c) => c, } } } @@ -371,6 +422,8 @@ mod test { test_binomial_mean_and_variance(40, 0.5, &mut rng); test_binomial_mean_and_variance(20, 0.7, &mut rng); test_binomial_mean_and_variance(20, 0.5, &mut rng); + test_binomial_mean_and_variance(1 << 61, 1e-17, &mut rng); + test_binomial_mean_and_variance(u64::MAX, 1e-19, &mut rng); } #[test]