Skip to content

Commit

Permalink
Closes Bears-R-Us#3245: Add poisson distribution to random number g…
Browse files Browse the repository at this point in the history
…enerators (Bears-R-Us#3253)

* Closes Bears-R-Us#3245: Add `poisson` distribution to random number generators

This PR (closes Bears-R-Us#3245) adds the `poisson` distribution to random number generators. This also adds the docs and testing including hypothesis testing against the probabilities expected from the poisson probability mass function. This PR does a coforall over the locals and tasks so I can ensure each task gets a unique seed that is reproducable

* update compat modules

* remove unused import

* forgot to add set seed back to hypothesis test

* update in reponse to pr feedback

---------

Co-authored-by: Tess Hayes <[email protected]>
  • Loading branch information
stress-tess and stress-tess authored Jun 3, 2024
1 parent cd49038 commit ace1e32
Show file tree
Hide file tree
Showing 8 changed files with 319 additions and 52 deletions.
95 changes: 70 additions & 25 deletions PROTO_tests/tests/random_test.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,5 @@
from collections import Counter

import numpy as np
import pytest
from scipy import stats as sp_stats
Expand Down Expand Up @@ -128,31 +130,6 @@ def test_uniform(self):
assert all(bounded_arr.to_ndarray() >= -5)
assert all(bounded_arr.to_ndarray() < 5)

def test_choice_hypothesis_testing(self):
# perform a weighted sample and use chisquare to test
# if the observed frequency matches the expected frequency

# I tested this many times without a set seed, but with no seed
# it's expected to fail one out of every ~20 runs given a pval limit of 0.05.
rng = ak.random.default_rng(43)
num_samples = 10**4

weights = ak.array([0.25, 0.15, 0.20, 0.10, 0.30])
weighted_sample = rng.choice(ak.arange(5), size=num_samples, p=weights)

# count how many of each category we saw
uk, f_obs = ak.GroupBy(weighted_sample).size()

# I think the keys should always be sorted but just in case
if not ak.is_sorted(uk):
f_obs = f_obs[ak.argsort(uk)]

f_exp = weights * num_samples
_, pval = akchisquare(f_obs=f_obs, f_exp=f_exp)

# if pval <= 0.05, the difference from the expected distribution is significant
assert pval > 0.05

def test_choice(self):
# verify without replacement works
rng = ak.random.default_rng()
Expand Down Expand Up @@ -223,6 +200,47 @@ def test_normal(self):
== both_array
)

def test_poissson(self):
rng = ak.random.default_rng(17)
num_samples = 5
# scalar lambda
scal_lam = 2
scal_sample = rng.poisson(lam=scal_lam, size=num_samples).to_list()

# array lambda
arr_lam = ak.arange(5)
arr_sample = rng.poisson(lam=arr_lam, size=num_samples).to_list()

# reset rng with same seed and ensure we get same results
rng = ak.random.default_rng(17)
assert rng.poisson(lam=scal_lam, size=num_samples).to_list() == scal_sample
assert rng.poisson(lam=arr_lam, size=num_samples).to_list() == arr_sample

def test_choice_hypothesis_testing(self):
# perform a weighted sample and use chisquare to test
# if the observed frequency matches the expected frequency

# I tested this many times without a set seed, but with no seed
# it's expected to fail one out of every ~20 runs given a pval limit of 0.05.
rng = ak.random.default_rng(43)
num_samples = 10**4

weights = ak.array([0.25, 0.15, 0.20, 0.10, 0.30])
weighted_sample = rng.choice(ak.arange(5), size=num_samples, p=weights)

# count how many of each category we saw
uk, f_obs = ak.GroupBy(weighted_sample).size()

# I think the keys should always be sorted but just in case
if not ak.is_sorted(uk):
f_obs = f_obs[ak.argsort(uk)]

f_exp = weights * num_samples
_, pval = akchisquare(f_obs=f_obs, f_exp=f_exp)

# if pval <= 0.05, the difference from the expected distribution is significant
assert pval > 0.05

def test_normal_hypothesis_testing(self):
# I tested this many times without a set seed, but with no seed
# it's expected to fail one out of every ~20 runs given a pval limit of 0.05.
Expand All @@ -246,6 +264,33 @@ def test_normal_hypothesis_testing(self):
)
assert good_fit_res.pvalue > 0.05

def test_poisson_hypothesis_testing(self):
# I tested this many times without a set seed, but with no seed
# it's expected to fail one out of every ~20 runs given a pval limit of 0.05.
rng = ak.random.default_rng(43)
num_samples = 10**4
lam = rng.uniform(0, 10)

sample = rng.poisson(lam=lam, size=num_samples)
count_dict = Counter(sample.to_list())

# the sum of exp freq must be within 1e-08, so use the cdf to find out how many
# elements we need to ensure we're within that tolerance
tol = 1e-09
num_elems = 5
while (1 - sp_stats.poisson.cdf(num_elems, mu=lam)) > tol:
num_elems += 5

obs_counts = np.array([0] * num_elems)
for k, v in count_dict.items():
obs_counts[k] = v

# use the probability mass function to get the probability of seeing each value
# and multiply by num_samples to get the expected counts
exp_counts = sp_stats.poisson.pmf(range(num_elems), mu=lam) * num_samples
_, pval = sp_stats.chisquare(f_obs=obs_counts, f_exp=exp_counts)
assert pval > 0.05

def test_legacy_randint(self):
testArray = ak.random.randint(0, 10, 5)
assert isinstance(testArray, ak.pdarray)
Expand Down
78 changes: 76 additions & 2 deletions arkouda/random/_generator.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,8 +5,9 @@
from arkouda.dtypes import bool as akbool
from arkouda.dtypes import dtype as to_numpy_dtype
from arkouda.dtypes import float64 as akfloat64
from arkouda.dtypes import float_scalars
from arkouda.dtypes import int64 as akint64
from arkouda.dtypes import int_scalars
from arkouda.dtypes import int_scalars, numeric_scalars
from arkouda.dtypes import uint64 as akuint64
from arkouda.pdarrayclass import create_pdarray, pdarray

Expand Down Expand Up @@ -354,7 +355,7 @@ def standard_normal(self, size=None):
},
)
# since we generate 2*size uniform samples for box-muller transform
self._state += (size * 2)
self._state += size * 2
return create_pdarray(rep_msg)

def shuffle(self, x):
Expand Down Expand Up @@ -429,6 +430,79 @@ def permutation(self, x):
self._state += size
return create_pdarray(rep_msg)

def poisson(self, lam=1.0, size=None):
r"""
Draw samples from a Poisson distribution.
The Poisson distribution is the limit of the binomial distribution for large N.
Parameters
----------
lam: float or pdarray
Expected number of events occurring in a fixed-time interval, must be >= 0.
An array must have the same size as the size argument.
size: numeric_scalars, optional
Output shape. Default is None, in which case a single value is returned.
Notes
-----
The probability mass function for the Poisson distribution is:
.. math::
f(k; \lambda) = \frac{\lambda^k e^{-\lambda}}{k!}
For events with an expected separation :math:`\lambda`, the Poisson distribution
:math:`f(k; \lambda)` describes the probability of :math:`k` events occurring
within the observed interval :math:`\lambda`
Returns
-------
pdarray
Pdarray of ints (unless size=None, in which case a single int is returned).
Examples
--------
>>> rng = ak.random.default_rng()
>>> rng.poisson(lam=3, size=5)
array([5 3 2 2 3]) # random
"""
if size is None:
# delegate to numpy when return size is 1
return self._np_generator.poisson(lam, size)

if _val_isinstance_of_union(lam, numeric_scalars):
is_single_lambda = True
if not _val_isinstance_of_union(lam, float_scalars):
lam = float(lam)
if lam < 0:
raise TypeError("lambda must be >=0")
elif isinstance(lam, pdarray):
is_single_lambda = False
if size != lam.size:
raise TypeError("array of lambdas must have same size as return size")
if lam.dtype != akfloat64:
from arkouda.numeric import cast as akcast

lam = akcast(lam, akfloat64)
if (lam < 0).any():
raise TypeError("all lambdas must be >=0")
else:
raise TypeError("poisson only accepts a pdarray or float scalar for lam")

rep_msg = generic_msg(
cmd="poissonGenerator",
args={
"name": self._name_dict[akfloat64],
"lam": lam,
"is_single_lambda": is_single_lambda,
"size": size,
"state": self._state,
},
)
# we only generate one val using the generator in the symbol table
self._state += 1
return create_pdarray(rep_msg)

def uniform(self, low=0.0, high=1.0, size=None):
"""
Draw samples from a uniform distribution.
Expand Down
4 changes: 4 additions & 0 deletions pydoc/usage/random.rst
Original file line number Diff line number Diff line change
Expand Up @@ -33,6 +33,10 @@ permutation
---------
.. autofunction:: arkouda.random.Generator.permutation

poisson
---------
.. autofunction:: arkouda.random.Generator.poisson

shuffle
---------
.. autofunction:: arkouda.random.Generator.shuffle
Expand Down
96 changes: 96 additions & 0 deletions src/RandMsg.chpl
Original file line number Diff line number Diff line change
Expand Up @@ -547,6 +547,101 @@ module RandMsg
return new MsgTuple(repMsg, MsgType.NORMAL);
}

proc poissonGeneratorMsg(cmd: string, msgArgs: borrowed MessageArgs, st: borrowed SymTab): MsgTuple throws {
const pn = Reflection.getRoutineName(),
name = msgArgs.getValueOf("name"), // generator name
isSingleLam = msgArgs.get("is_single_lambda").getBoolValue(), // boolean indicated if lambda is a single value or array
lamStr = msgArgs.getValueOf("lam"), // lambda for poisson distribution
size = msgArgs.get("size").getIntValue(), // number of values to be generated
state = msgArgs.get("state").getIntValue(), // rng state
rname = st.nextName();


randLogger.debug(getModuleName(),getRoutineName(),getLineNumber(),
"name: %? size %i isSingleLam %? lamStr %? state %i".doFormat(name, size, isSingleLam, lamStr, state));

st.checkTable(name);

var generatorEntry: borrowed GeneratorSymEntry(real) = toGeneratorSymEntry(st.lookup(name), real);
ref rng = generatorEntry.generator;
if state != 1 {
// you have to skip to one before where you want to be
rng.skipTo(state-1);
}

// uses the algorithm from knuth found here:
// https://en.wikipedia.org/wiki/Poisson_distribution#Random_variate_generation
// generates values drawn from poisson distribution using a stream of uniformly distributed random numbers

const nTasksPerLoc = here.maxTaskPar; // tasks per locale based on locale0
const Tasks = {0..#nTasksPerLoc}; // these need to be const for comms/performance reasons

const generatorSeed = (rng.next() * 2**62):int;
var poissonArr = makeDistArray(size, int);

// I hate the code duplication here but it's not immediately obvious to me how to avoid it
if isSingleLam {
const lam = lamStr:real;
// using nested coforall over locales and tasks so we know how to generate taskSeed
for loc in Locales do on loc {
const generatorIdxOffset = here.id * nTasksPerLoc,
locSubDom = poissonArr.localSubdomain(), // the chunk that this locale needs to handle
indicesPerTask = locSubDom.size / nTasksPerLoc; // the number of elements each task needs to handle

coforall tid in Tasks {
const taskSeed = generatorSeed + generatorIdxOffset + tid, // initial seed offset by other locales threads plus current thread id
startIdx = tid * indicesPerTask,
stopIdx = if tid == nTasksPerLoc - 1 then locSubDom.size else (tid + 1) * indicesPerTask; // the last task picks up the remainder of indices
var rs = new randomStream(real, taskSeed);
for i in startIdx..<stopIdx {
var L = exp(-lam);
var k = 0;
var p = 1.0;

do {
k += 1;
p = p * rs.next(0, 1);
} while p > L;
poissonArr[locSubDom.low + i] = k - 1;
}
}
}
}
else {
st.checkTable(lamStr);
const lamArr = toSymEntry(getGenericTypedArrayEntry(lamStr, st),real).a;
// using nested coforall over locales and task so we know exactly how many generators we need
for loc in Locales do on loc {
const generatorIdxOffset = here.id * nTasksPerLoc,
locSubDom = poissonArr.localSubdomain(), // the chunk that this locale needs to handle
indicesPerTask = locSubDom.size / nTasksPerLoc; // the number of elements each task needs to handle

coforall tid in Tasks {
const taskSeed = generatorSeed + generatorIdxOffset + tid,
startIdx = tid * indicesPerTask,
stopIdx = if tid == nTasksPerLoc - 1 then locSubDom.size else (tid + 1) * indicesPerTask; // the last task picks up the remainder of indices
var rs = new randomStream(real, taskSeed);
for i in startIdx..<stopIdx {
const lam = lamArr[locSubDom.low + i];
var L = exp(-lam);
var k = 0;
var p = 1.0;

do {
k += 1;
p = p * rs.next(0, 1);
} while p > L;
poissonArr[locSubDom.low + i] = k - 1;
}
}
}
}
st.addEntry(rname, createSymEntry(poissonArr));
const repMsg = "created " + st.attrib(rname);
randLogger.debug(getModuleName(),pn,getLineNumber(),repMsg);
return new MsgTuple(repMsg, MsgType.NORMAL);
}

proc shuffleMsg(cmd: string, msgArgs: borrowed MessageArgs, st: borrowed SymTab): MsgTuple throws {
const pn = Reflection.getRoutineName();
const name = msgArgs.getValueOf("name");
Expand Down Expand Up @@ -606,5 +701,6 @@ module RandMsg
registerFunction("segmentedSample", segmentedSampleMsg, getModuleName());
registerFunction("choice", choiceMsg, getModuleName());
registerFunction("permutation", permutationMsg, getModuleName());
registerFunction("poissonGenerator", poissonGeneratorMsg, getModuleName());
registerFunction("shuffle", shuffleMsg, getModuleName());
}
1 change: 1 addition & 0 deletions src/compat/e-132/ArkoudaRandomCompat.chpl
Original file line number Diff line number Diff line change
Expand Up @@ -38,6 +38,7 @@ module ArkoudaRandomCompat {
return x[idx];
}
proc ref next(): eltType do return r.getNext();
proc ref next(min: eltType, max: eltType): eltType do return r.getNext(min, max);
proc skipTo(n: int) do try! r.skipToNth(n);
}

Expand Down
1 change: 1 addition & 0 deletions src/compat/eq-131/ArkoudaRandomCompat.chpl
Original file line number Diff line number Diff line change
Expand Up @@ -38,6 +38,7 @@ module ArkoudaRandomCompat {
return x[idx];
}
proc ref next(): eltType do return r.getNext();
proc ref next(min: eltType, max: eltType): eltType do return r.getNext(min, max);
proc skipTo(n: int) do try! r.skipToNth(n);
}

Expand Down
1 change: 1 addition & 0 deletions src/compat/eq-133/ArkoudaRandomCompat.chpl
Original file line number Diff line number Diff line change
Expand Up @@ -26,6 +26,7 @@ module ArkoudaRandomCompat {
return x[idx];
}
proc ref randomStream.next() do return this.getNext();
proc ref randomStream.next(min: eltType, max: eltType): eltType do return this.getNext(min, max);

proc choiceUniform(ref stream, X: domain(?), size: ?sizeType, replace: bool) throws
{
Expand Down
Loading

0 comments on commit ace1e32

Please sign in to comment.