Skip to content

Commit

Permalink
SamplingConfig: support for batch_size=None
Browse files Browse the repository at this point in the history
  • Loading branch information
Michael Panchenko committed Apr 3, 2024
1 parent e8ba5ad commit e4d7d2f
Show file tree
Hide file tree
Showing 3 changed files with 8 additions and 3 deletions.
2 changes: 1 addition & 1 deletion examples/mujoco/mujoco_reinforce_hl.py
Original file line number Diff line number Diff line change
Expand Up @@ -29,7 +29,7 @@ def main(
step_per_epoch: int = 30000,
step_per_collect: int = 2048,
repeat_per_collect: int = 1,
batch_size: int = 16,
batch_size: int | None = None,
training_num: int = 10,
test_num: int = 10,
rew_norm: bool = True,
Expand Down
7 changes: 6 additions & 1 deletion tianshou/highlevel/config.py
Original file line number Diff line number Diff line change
Expand Up @@ -37,12 +37,17 @@ class SamplingConfig(ToStringMixin):
an explanation of epoch semantics.
"""

batch_size: int = 64
batch_size: int | None = 64
"""for off-policy algorithms, this is the number of environment steps/transitions to sample
from the buffer for a gradient update; for on-policy algorithms, its use is algorithm-specific.
On-policy algorithms use the full buffer that was collected in the preceding collection step
but they may use this parameter to perform the gradient update using mini-batches of this size
(causing the gradient to be less accurate, a form of regularization).
``batch_size=None`` means that the full buffer is used for the gradient update (no mini-batching).
This rarely makes sense for off-policy algorithms and is not recommended for them.
For on-policy or offline algorithms, using the full buffer for the gradient update
may make sense in some cases.
"""

num_train_envs: int = -1
Expand Down
2 changes: 1 addition & 1 deletion tianshou/utils/net/common.py
Original file line number Diff line number Diff line change
Expand Up @@ -448,7 +448,7 @@ def forward(self, x: torch.Tensor) -> torch.Tensor:
return x


#TODO: fix docstring
# TODO: fix docstring
class BranchingNet(NetBase[Any]):
"""Branching dual Q network.
Expand Down

0 comments on commit e4d7d2f

Please sign in to comment.