feat: improve environment reset speed by changing jax.random.choice #96

Open
dluo96 opened this issue Mar 28, 2023 · 0 comments
Labels
enhancement New feature or request

Comments

@dluo96
Contributor

dluo96 commented Mar 28, 2023

Is your feature request related to a problem? Please describe

jax.random.choice seems to be slow, especially when sampling without replacement. Sampling with replacement seems to be much faster, but even that is probably slower than jax.random.randint (to be verified).

Describe the solution you'd like

Find a way to use jax.random.choice(..., replace=False) as little as possible to improve environment speed.

Alternatives considered

A first study towards this shows that jax.random.choice(..., replace=True) is faster than jax.random.categorical for sampling with replacement, and that jax.random.choice(..., replace=False) appears much slower than both. The best approach for cases where sampling without replacement is required still has to be studied.
Source: notebook
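For anyone wanting to reproduce the comparison locally, here is a minimal timing sketch. It is not the notebook's code: NUM_ITEMS, NUM_SAMPLES, the repeat count and the helper time_sampler are assumptions chosen for illustration, and the numbers will vary with backend and problem size.

```python
import timeit

import jax
import jax.numpy as jnp

NUM_ITEMS = 100   # assumed population size, not taken from the notebook
NUM_SAMPLES = 2   # e.g. two samples, as in Snake

key = jax.random.PRNGKey(0)
logits = jnp.zeros(NUM_ITEMS)  # uniform distribution for categorical

# Each sampler is jitted so that only the compiled call is timed.
samplers = {
    "choice(replace=True)": jax.jit(
        lambda k: jax.random.choice(k, NUM_ITEMS, shape=(NUM_SAMPLES,), replace=True)
    ),
    "choice(replace=False)": jax.jit(
        lambda k: jax.random.choice(k, NUM_ITEMS, shape=(NUM_SAMPLES,), replace=False)
    ),
    "categorical": jax.jit(
        lambda k: jax.random.categorical(k, logits, shape=(NUM_SAMPLES,))
    ),
    "randint": jax.jit(
        lambda k: jax.random.randint(k, (NUM_SAMPLES,), 0, NUM_ITEMS)
    ),
}


def time_sampler(fn, repeats=200):
    """Average seconds per call, excluding compilation."""
    fn(key).block_until_ready()  # warm-up call triggers compilation
    total = timeit.timeit(lambda: fn(key).block_until_ready(), number=repeats)
    return total / repeats


timings = {name: time_sampler(fn) for name, fn in samplers.items()}
for name, seconds in timings.items():
    print(f"{name:>22}: {seconds * 1e6:.1f} us per call")
```

Note that randint and choice(replace=True) may return repeated indices, so they are only a baseline here and not a drop-in replacement for replace=False.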

Alternatives to jax.random.choice(..., replace=False) that could be considered and assessed include:

  • Sampling once from the joint distribution where the joint gathers all the valid pairs
  • Creating two random partitions p1 and p2, sampling a single index i, and taking the two samples as p1[i] and p2[i]
  • Sequentially sampling one index, then sampling a second one from the conditional distribution given that the first is no longer available
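As a rough illustration of the joint and sequential ideas for the two-sample case, here is a sketch. The function names and the valid_mask argument are hypothetical and not part of the codebase; the first two functions assume a uniform distribution over range(num_items), and none of this has been benchmarked against jax.random.choice(..., replace=False).

```python
import jax
import jax.numpy as jnp


def sample_two_sequential(key, num_items):
    """Two distinct indices from range(num_items) using two randint calls."""
    key_first, key_second = jax.random.split(key)
    first = jax.random.randint(key_first, (), 0, num_items)
    # Draw from the num_items - 1 remaining slots, then skip over `first`.
    second = jax.random.randint(key_second, (), 0, num_items - 1)
    second = second + (second >= first)
    return jnp.stack([first, second])


def sample_two_joint(key, num_items):
    """One randint draw over all num_items * (num_items - 1) ordered pairs."""
    pair = jax.random.randint(key, (), 0, num_items * (num_items - 1))
    first = pair // (num_items - 1)
    offset = pair % (num_items - 1)
    # Decode the pair so that `second` can never equal `first`.
    second = offset + (offset >= first)
    return jnp.stack([first, second])


def sample_two_masked(key, valid_mask):
    """Sequential sampling restricted to entries where valid_mask is True.

    Assumes at least two entries of valid_mask are True.
    """
    key_first, key_second = jax.random.split(key)
    logits = jnp.where(valid_mask, 0.0, -jnp.inf)
    first = jax.random.categorical(key_first, logits)
    # Condition on the first draw by making it unavailable.
    logits = logits.at[first].set(-jnp.inf)
    second = jax.random.categorical(key_second, logits)
    return jnp.stack([first, second])
```

All three can be wrapped in jax.jit (with num_items marked static, e.g. via static_argnums) and timed in the same way as jax.random.choice to check whether they are actually faster.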

It may be that jax.random.choice(..., replace=False) ends up being the most optimised version. In any case, the solution may depend on how many samples have to be drawn without replacement (e.g. 2 in the case of Snake).

Remarks

We need to take this into account for random policies. Random action selection likely has a large influence on the measured environment speed, which would bias speed benchmarks.

@dluo96 dluo96 added the enhancement New feature or request label Mar 28, 2023
@dluo96 dluo96 closed this as completed Mar 28, 2023
@dluo96 dluo96 reopened this Mar 28, 2023