Skip to content

Commit

Permalink
Merge pull request #834 from mlcommons/random_utils
Browse files Browse the repository at this point in the history
Fixes to random_utils.py
  • Loading branch information
priyakasimbeg authored Jan 18, 2025
2 parents fe90379 + fc526a4 commit 6c8fd56
Showing 1 changed file with 8 additions and 8 deletions.
16 changes: 8 additions & 8 deletions algorithmic_efficiency/random_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,32 +16,32 @@

FLAGS = flags.FLAGS

# Annoyingly, RandomState(seed) requires seed to be in [0, 2 ** 32 - 1] (an
# Annoyingly, RandomState(seed) requires seed to be in [0, 2 ** 31 - 1] (an
# unsigned int), while RandomState.randint only accepts and returns signed ints.
MAX_UINT32 = 2**32 - 1
MIN_UINT32 = 0
MAX_INT32 = 2**31 - 1
MIN_INT32 = 0

SeedType = Union[int, list, np.ndarray]


def _signed_to_unsigned(seed: SeedType) -> SeedType:
if isinstance(seed, int):
return seed % MAX_UINT32
return seed % MAX_INT32
if isinstance(seed, list):
return [s % MAX_UINT32 for s in seed]
return [s % MAX_INT32 for s in seed]
if isinstance(seed, np.ndarray):
return np.array([s % MAX_UINT32 for s in seed.tolist()])
return np.array([s % MAX_INT32 for s in seed.tolist()])


def _fold_in(seed: SeedType, data: Any) -> List[Union[SeedType, Any]]:
rng = np.random.RandomState(seed=_signed_to_unsigned(seed))
new_seed = rng.randint(MIN_UINT32, MAX_UINT32, dtype=np.uint32)
new_seed = rng.randint(MIN_INT32, MAX_INT32, dtype=np.int32)
return [new_seed, data]


def _split(seed: SeedType, num: int = 2) -> SeedType:
rng = np.random.RandomState(seed=_signed_to_unsigned(seed))
return rng.randint(MIN_UINT32, MAX_UINT32, dtype=np.uint32, size=[num, 2])
return rng.randint(MIN_INT32, MAX_INT32, dtype=np.int32, size=[num, 2])


def _PRNGKey(seed: SeedType) -> SeedType: # pylint: disable=invalid-name
Expand Down

0 comments on commit 6c8fd56

Please sign in to comment.