From 0ceb5665a8e109ec48f7c5700f6c23227e654081 Mon Sep 17 00:00:00 2001 From: tess <48131946+stress-tess@users.noreply.github.com> Date: Tue, 4 Jun 2024 16:33:01 -0400 Subject: [PATCH] poisson cleanup (#3280) * small updates found from team code share / knowledge transfer * missed a spot --------- Co-authored-by: Tess Hayes --- PROTO_tests/tests/random_test.py | 10 ++++------ src/RandMsg.chpl | 8 ++++---- tests/random_test.py | 10 ++++------ 3 files changed, 12 insertions(+), 16 deletions(-) diff --git a/PROTO_tests/tests/random_test.py b/PROTO_tests/tests/random_test.py index 29f2a42c4e..d02b69ac59 100644 --- a/PROTO_tests/tests/random_test.py +++ b/PROTO_tests/tests/random_test.py @@ -274,12 +274,10 @@ def test_poisson_hypothesis_testing(self): sample = rng.poisson(lam=lam, size=num_samples) count_dict = Counter(sample.to_list()) - # the sum of exp freq must be within 1e-08, so use the cdf to find out how many - # elements we need to ensure we're within that tolerance - tol = 1e-09 - num_elems = 5 - while (1 - sp_stats.poisson.cdf(num_elems, mu=lam)) > tol: - num_elems += 5 + # the sum of exp freq and obs freq must be within 1e-08, so we use + # the isf (inverse survival function where survival function is 1-cdf) to + # find out how many elements we need to ensure we're within that tolerance + num_elems = int(sp_stats.poisson.isf(1e-09, mu=lam)) obs_counts = np.array([0] * num_elems) for k, v in count_dict.items(): diff --git a/src/RandMsg.chpl b/src/RandMsg.chpl index 3e2a7cd6ad..00318bfcc4 100644 --- a/src/RandMsg.chpl +++ b/src/RandMsg.chpl @@ -582,8 +582,8 @@ module RandMsg // I hate the code duplication here but it's not immediately obvious to me how to avoid it if isSingleLam { const lam = lamStr:real; - // using nested coforall over locales and tasks so we know how to generate taskSeed - for loc in Locales do on loc { + // using nested coforalls over locales and tasks so we know how to generate taskSeed + coforall loc in Locales do on loc { const generatorIdxOffset = here.id * nTasksPerLoc, locSubDom = poissonArr.localSubdomain(), // the chunk that this locale needs to handle indicesPerTask = locSubDom.size / nTasksPerLoc; // the number of elements each task needs to handle @@ -610,8 +610,8 @@ module RandMsg else { st.checkTable(lamStr); const lamArr = toSymEntry(getGenericTypedArrayEntry(lamStr, st),real).a; - // using nested coforall over locales and task so we know exactly how many generators we need - for loc in Locales do on loc { + // using nested coforalls over locales and task so we know exactly how many generators we need + coforall loc in Locales do on loc { const generatorIdxOffset = here.id * nTasksPerLoc, locSubDom = poissonArr.localSubdomain(), // the chunk that this locale needs to handle indicesPerTask = locSubDom.size / nTasksPerLoc; // the number of elements each task needs to handle diff --git a/tests/random_test.py b/tests/random_test.py index 3679eff31f..a0f76738c0 100644 --- a/tests/random_test.py +++ b/tests/random_test.py @@ -281,12 +281,10 @@ def test_poisson_hypothesis_testing(self): sample = rng.poisson(lam=lam, size=num_samples) count_dict = Counter(sample.to_list()) - # the sum of exp freq must be within 1e-08, so use the cdf to find out how many - # elements we need to ensure we're within that tolerance - tol = 1e-09 - num_elems = 5 - while (1 - sp_stats.poisson.cdf(num_elems, mu=lam)) > tol: - num_elems += 5 + # the sum of exp freq and obs freq must be within 1e-08, so we use + # the isf (inverse survival function where survival function is 1-cdf) to + # find out how many elements we need to ensure we're within that tolerance + num_elems = int(sp_stats.poisson.isf(1e-09, mu=lam)) obs_counts = np.array([0] * num_elems) for k, v in count_dict.items():