Skip to content

Commit

Permalink
Closes Bears-R-Us#3086: Add choice to random number generators (Bea…
Browse files Browse the repository at this point in the history
…rs-R-Us#3114)

* Closes Bears-R-Us#3086: Add `choice` to random number generators

This PR (closes Bears-R-Us#3086) adds `choice` to random number generators.

I had to make a very slight modification to `sampleDomWeighted` (just moving the increment to be after indexing) to avoid an out of bounds error.

* skip test until issue Bears-R-Us#3118 is resolved

* update tests to use size aggregation instead of count

---------

Co-authored-by: Tess Hayes <[email protected]>
  • Loading branch information
stress-tess and stress-tess authored Apr 25, 2024
1 parent 18edda0 commit 35539e5
Show file tree
Hide file tree
Showing 8 changed files with 344 additions and 51 deletions.
80 changes: 80 additions & 0 deletions PROTO_tests/tests/random_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@
import pytest

import arkouda as ak
from arkouda.scipy import chisquare as akchisquare


class TestRandom:
Expand Down Expand Up @@ -126,6 +127,85 @@ def test_uniform(self):
assert all(bounded_arr.to_ndarray() >= -5)
assert all(bounded_arr.to_ndarray() < 5)

def test_choice_hypothesis_testing(self):
# perform a weighted sample and use chisquare to test
# if the observed frequency matches the expected frequency

# I tested this many times without a set seed, but with no seed
# it's expected to fail one out of every ~20 runs given a pval limit of 0.05.
rng = ak.random.default_rng(43)
num_samples = 10**4

weights = ak.array([0.25, 0.15, 0.20, 0.10, 0.30])
weighted_sample = rng.choice(ak.arange(5), size=num_samples, p=weights)

# count how many of each category we saw
uk, f_obs = ak.GroupBy(weighted_sample).size()

# I think the keys should always be sorted but just in case
if not ak.is_sorted(uk):
f_obs = f_obs[ak.argsort(uk)]

f_exp = weights * num_samples
_, pval = akchisquare(f_obs=f_obs, f_exp=f_exp)

# if pval <= 0.05, the difference from the expected distribution is significant
assert pval > 0.05

def test_choice(self):
# verify without replacement works
rng = ak.random.default_rng()
# test domains and selecting all
domain_choice = rng.choice(20, 20, replace=False)
# since our populations and sample size is the same without replacement,
# we should see all values
assert (ak.sort(domain_choice) == ak.arange(20)).all()

# test arrays and not selecting all
perm = rng.permutation(100)
array_choice = rng.choice(perm, 95, replace=False)
# verify all unique
_, count = ak.GroupBy(array_choice).size()
assert (count == 1).all()

# test single value
scalar = rng.choice(5)
assert type(scalar) is np.int64
assert scalar in [0, 1, 2, 3, 4]

@pytest.mark.skip(reason="skip until issue #3118 is resolved")
def test_choice_flags(self):
# use numpy to randomly generate a set seed
seed = np.random.default_rng().choice(2**63)

rng = ak.random.default_rng(seed)
weights = rng.uniform(size=10)
a_vals = [
10,
rng.integers(0, 2**32, size=10, dtype="uint"),
rng.uniform(-1.0, 1.0, size=10),
rng.integers(0, 1, size=10, dtype="bool"),
rng.integers(-(2**32), 2**32, size=10, dtype="int"),
]

rng = ak.random.default_rng(seed)
choice_arrays = []
for a in a_vals:
for size in 5, 10:
for replace in True, False:
for p in [None, weights]:
choice_arrays.append(rng.choice(a, size, replace, p))

# reset generator to ensure we get the same arrays
rng = ak.random.default_rng(seed)
for a in a_vals:
for size in 5, 10:
for replace in True, False:
for p in [None, weights]:
previous = choice_arrays.pop(0)
current = rng.choice(a, size, replace, p)
assert np.allclose(previous.to_list(), current.to_list())

def test_legacy_randint(self):
testArray = ak.random.randint(0, 10, 5)
assert isinstance(testArray, ak.pdarray)
Expand Down
84 changes: 83 additions & 1 deletion arkouda/random/_generator.py
Original file line number Diff line number Diff line change
Expand Up @@ -55,6 +55,88 @@ def __str__(self):
_str += "(PCG64)"
return _str

def choice(self, a, size=None, replace=True, p=None):
"""
Generates a randomly sample from a.
Parameters
----------
a: int or pdarray
If a is an integer, randomly sample from ak.arange(a).
If a is a pdarray, randomly sample from a.
size: int, optional
Number of elements to be sampled
replace: bool, optional
If True, sample with replacement. Otherwise sample without replacement.
Defaults to True
p: pdarray, optional
p is the probabilities or weights associated with each element of a
Returns
-------
pdarray, numeric_scalar
A pdarray containing the sampled values or a single random value if size not provided.
"""
if size is None:
ret_scalar = True
size = 1
else:
ret_scalar = False

from arkouda.numeric import cast as akcast

if _val_isinstance_of_union(a, int_scalars):
is_domain = True
dtype = to_numpy_dtype(akint64)
pop_size = a
elif isinstance(a, pdarray):
is_domain = False
dtype = to_numpy_dtype(a.dtype)
pop_size = a.size
else:
raise TypeError("choice only accepts a pdarray or int scalar.")

if not replace and size > pop_size:
raise ValueError("Cannot take a larger sample than population when replace is False")

has_weights = p is not None
if has_weights:
if not isinstance(p, pdarray):
raise TypeError("weights must be a pdarray")
if p.dtype != akfloat64:
p = akcast(p, akfloat64)
else:
p = ""

# weighted sample requires float and non-weighted uses int
name = self._name_dict[to_numpy_dtype(akfloat64 if has_weights else akint64)]

rep_msg = generic_msg(
cmd="choice",
args={
"gName": name,
"aName": a,
"wName": p,
"numSamples": size,
"replace": replace,
"hasWeights": has_weights,
"isDom": is_domain,
"popSize": pop_size,
"dtype": dtype,
"state": self._state,
},
)
# for the non-weighted domain case we pull pop_size numbers from the generator.
# for other cases we may be more than the numbers drawn, but that's okay. The important
# thing is not repeating any positions in the state.
self._state += pop_size

pda = create_pdarray(rep_msg)
return pda if not ret_scalar else pda[0]

def integers(self, low, high=None, size=None, dtype=akint64, endpoint=False):
"""
Return random integers from low (inclusive) to high (exclusive),
Expand Down Expand Up @@ -267,7 +349,7 @@ def permutation(self, x):
dtype = to_numpy_dtype(x.dtype)
size = x.size
else:
raise TypeError("permtation only accepts a pdarray or int scalar.")
raise TypeError("permutation only accepts a pdarray or int scalar.")

# we have to use the int version since we permute the domain
name = self._name_dict[to_numpy_dtype(akint64)]
Expand Down
2 changes: 1 addition & 1 deletion src/RandArray.chpl
Original file line number Diff line number Diff line change
Expand Up @@ -207,8 +207,8 @@ module RandArray {
// add the sampled index to the list of indices and the list of samples
weightsCopy[ii] = 0;
indices += ii;
i += 1;
samples[i] = ii;
i += 1;

// recompute the normalized cumulative weights
cw = + scan weightsCopy;
Expand Down
135 changes: 87 additions & 48 deletions src/RandMsg.chpl
Original file line number Diff line number Diff line change
Expand Up @@ -297,6 +297,92 @@ module RandMsg
return new MsgTuple(repMsg, MsgType.NORMAL);
}

proc choiceMsg(cmd: string, msgArgs: borrowed MessageArgs, st: borrowed SymTab): MsgTuple throws {
const pn = Reflection.getRoutineName(),
gName = msgArgs.getValueOf("gName"), // generator name
aName = msgArgs.getValueOf("aName"), // values array name
wName = msgArgs.getValueOf("wName"), // weights array name
numSamples = msgArgs.get("numSamples").getIntValue(), // number of samples
replace = msgArgs.get("replace").getBoolValue(), // sample with replacement
hasWeights = msgArgs.get("hasWeights").getBoolValue(), // flag indicating whether weighted sample
isDom = msgArgs.get("isDom").getBoolValue(), // flag indicating whether return is domain or array
popSize = msgArgs.get("popSize").getIntValue(), // population size
dtypeStr = msgArgs.getValueOf("dtype"), // string version of dtype
dtype = str2dtype(dtypeStr), // DType enum
state = msgArgs.get("state").getIntValue(), // rng state
rname = st.nextName();

randLogger.debug(getModuleName(),pn,getLineNumber(),
"gname: %? aname %? wname: %? numSamples %i replace %i hasWeights %i isDom %i dtype %? popSize %? state %i rname %?"
.doFormat(gName, aName, wName, numSamples, replace, hasWeights, isDom, dtypeStr, popSize, state, rname));

proc weightedIdxHelper() throws {
var generatorEntry = toGeneratorSymEntry(st.lookup(gName), real);
ref rng = generatorEntry.generator;

if state != 1 then rng.skipTo(state-1);

st.checkTable(wName);
const weights = toSymEntry(getGenericTypedArrayEntry(wName, st),real).a;
return sampleDomWeighted(rng, numSamples, weights, replace);
}

proc idxHelper() throws {
var generatorEntry = toGeneratorSymEntry(st.lookup(gName), int);
ref rng = generatorEntry.generator;

if state != 1 then rng.skipTo(state-1);

const choiceDom = {0..<popSize};
return rng.sample(choiceDom, numSamples, replace);
}

proc choiceHelper(type t) throws {
// I had to break these 2 helpers out into seprate functions since they have different types for generatorEntry
const choiceIdx = if hasWeights then weightedIdxHelper() else idxHelper();

if isDom {
const choiceEntry = createSymEntry(choiceIdx);
st.addEntry(rname, choiceEntry);
}
else {
var choiceArr: [makeDistDom(numSamples)] t;
st.checkTable(aName);
const myArr = toSymEntry(getGenericTypedArrayEntry(aName, st),t).a;

forall (ca,idx) in zip(choiceArr, choiceIdx) with (var agg = newSrcAggregator(t)) {
agg.copy(ca, myArr[idx]);
}

const choiceEntry = createSymEntry(choiceArr);
st.addEntry(rname, choiceEntry);
}
const repMsg = "created " + st.attrib(rname);
randLogger.debug(getModuleName(),pn,getLineNumber(),repMsg);
return new MsgTuple(repMsg, MsgType.NORMAL);
}

select dtype {
when DType.Int64 {
return choiceHelper(int);
}
when DType.UInt64 {
return choiceHelper(uint);
}
when DType.Float64 {
return choiceHelper(real);
}
when DType.Bool {
return choiceHelper(bool);
}
otherwise {
const errorMsg = "Unhandled data type %s".doFormat(dtypeStr);
randLogger.error(getModuleName(),pn,getLineNumber(),errorMsg);
return new MsgTuple(notImplementedError(pn, errorMsg), MsgType.ERROR);
}
}
}

proc permutationMsg(cmd: string, msgArgs: borrowed MessageArgs, st: borrowed SymTab): MsgTuple throws {
const pn = Reflection.getRoutineName();
var rname = st.nextName();
Expand Down Expand Up @@ -419,58 +505,11 @@ module RandMsg
return new MsgTuple(repMsg, MsgType.NORMAL);
}

proc sampleWeightsMsg(cmd: string, msgArgs: borrowed MessageArgs, st: borrowed SymTab): MsgTuple throws {
const pn = Reflection.getRoutineName(),
gName = msgArgs.getValueOf("gName"), // generator name
aName = msgArgs.getValueOf("aName"), // values array name
wName = msgArgs.getValueOf("wName"), // weights array name
n = msgArgs.get("n").getIntValue(), // number of samples
replace = msgArgs.get("replace").getBoolValue(), // sample with replacement?
state = msgArgs.get("state").getIntValue(), // rng state
rname = st.nextName();

randLogger.debug(getModuleName(),pn,getLineNumber(),
"gname: %? aname %? wname: %? n %i replace %i state %i rname %?"
.doFormat(gName, aName, wName, n, replace, state, rname));

var aGEnt: borrowed GenSymEntry = getGenericTypedArrayEntry(aName, st),
wGEnt: borrowed GenSymEntry = getGenericTypedArrayEntry(wName, st);

proc sampleHelper(type t): MsgTuple throws {
const aE = toSymEntry(aGEnt, t),
wE = toSymEntry(wGEnt, real); // weights are always real

var generatorEntry: borrowed GeneratorSymEntry(real) = toGeneratorSymEntry(st.lookup(gName), real);
ref rng = generatorEntry.generator;
if state != 1 then rng.skipTo(state-1);

const s = randSampleWeights(rng, aE.a, wE.a, n, replace);
st.addEntry(rname, createSymEntry(s));

const repMsg = "created " + st.attrib(rname);
randLogger.debug(getModuleName(),pn,getLineNumber(),repMsg);
return new MsgTuple(repMsg, MsgType.NORMAL);
}

select aGEnt.dtype {
when DType.Int64 do return sampleHelper(int);
when DType.UInt64 do return sampleHelper(uint);
when DType.Float64 do return sampleHelper(real);
when DType.Bool do return sampleHelper(bool);
otherwise {
const errorMsg = "Unhandled data type %s".doFormat(dtype2str(aGEnt.dtype));
randLogger.error(getModuleName(),pn,getLineNumber(),errorMsg);
return new MsgTuple(notImplementedError(pn, errorMsg), MsgType.ERROR);
}
}

}

use CommandMap;
registerFunction("randomNormal", randomNormalMsg, getModuleName());
registerFunction("createGenerator", createGeneratorMsg, getModuleName());
registerFunction("uniformGenerator", uniformGeneratorMsg, getModuleName());
registerFunction("choice", choiceMsg, getModuleName());
registerFunction("permutation", permutationMsg, getModuleName());
registerFunction("shuffle", shuffleMsg, getModuleName());
registerFunction("sampleWeights", sampleWeightsMsg, getModuleName());
}
4 changes: 4 additions & 0 deletions src/compat/e-132/ArkoudaRandomCompat.chpl
Original file line number Diff line number Diff line change
Expand Up @@ -30,6 +30,10 @@ module ArkoudaRandomCompat {
r.permutation(domArr);
return domArr;
}

proc ref sample(d: domain(?), n: int, withReplacement = false): [] d.idxType throws where is1DRectangularDomain(d) {
return r.choice(d, n, withReplacement);
}
proc ref next(): eltType do return r.getNext();
proc skipTo(n: int) do try! r.skipToNth(n);
}
Expand Down
3 changes: 3 additions & 0 deletions src/compat/eq-131/ArkoudaRandomCompat.chpl
Original file line number Diff line number Diff line change
Expand Up @@ -30,6 +30,9 @@ module ArkoudaRandomCompat {
r.permutation(domArr);
return domArr;
}
proc ref sample(d: domain, n: int, withReplacement = false): [] d.idxType throws where is1DRectangularDomain(d) {
return r.choice(d, n, withReplacement);
}
proc ref next(): eltType do return r.getNext();
proc skipTo(n: int) do try! r.skipToNth(n);
}
Expand Down
5 changes: 4 additions & 1 deletion src/compat/eq-133/ArkoudaRandomCompat.chpl
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,9 @@ module ArkoudaRandomCompat {
this.permutation(domArr);
return domArr;
}

proc ref randomStream.sample(d: domain(?), n: int, withReplacement = false): [] d.idxType throws where is1DRectangularDomain(d) && isCoercible(this.eltType, d.idxType) {
// unfortunately there isn't a domain permutation function so we will create an array to permute
return this.choice(d, n, withReplacement);
}
proc ref randomStream.next() do return this.getNext();
}
Loading

0 comments on commit 35539e5

Please sign in to comment.