diff --git a/examples/mujoco/mujoco_reinforce_hl.py b/examples/mujoco/mujoco_reinforce_hl.py
index 5651ee1b8..bc07e050b 100644
--- a/examples/mujoco/mujoco_reinforce_hl.py
+++ b/examples/mujoco/mujoco_reinforce_hl.py
@@ -29,7 +29,7 @@ def main(
     step_per_epoch: int = 30000,
     step_per_collect: int = 2048,
     repeat_per_collect: int = 1,
-    batch_size: int = 16,
+    batch_size: int | None = None,
     training_num: int = 10,
     test_num: int = 10,
     rew_norm: bool = True,
diff --git a/tianshou/highlevel/config.py b/tianshou/highlevel/config.py
index 48dde374d..43a4db2e9 100644
--- a/tianshou/highlevel/config.py
+++ b/tianshou/highlevel/config.py
@@ -37,12 +37,17 @@ class SamplingConfig(ToStringMixin):
     an explanation of epoch semantics.
     """
 
-    batch_size: int = 64
+    batch_size: int | None = 64
     """for off-policy algorithms, this is the number of environment steps/transitions to sample
     from the buffer for a gradient update; for on-policy algorithms, its use is algorithm-specific.
     On-policy algorithms use the full buffer that was collected in the preceding collection step
     but they may use this parameter to perform the gradient update using mini-batches of this size
     (causing the gradient to be less accurate, a form of regularization).
+
+    ``batch_size=None`` means that the full buffer is used for the gradient update (no
+    mini-batching). This does not make much sense for off-policy algorithms and is therefore
+    not recommended in that case; for on-policy or offline algorithms, however, it may be a
+    sensible choice in some cases.
     """
 
     num_train_envs: int = -1
diff --git a/tianshou/utils/net/common.py b/tianshou/utils/net/common.py
index 3fb300261..eceee100f 100644
--- a/tianshou/utils/net/common.py
+++ b/tianshou/utils/net/common.py
@@ -448,7 +448,7 @@ def forward(self, x: torch.Tensor) -> torch.Tensor:
         return x
 
 
-#TODO: fix docstring
+# TODO: fix docstring
 class BranchingNet(NetBase[Any]):
     """Branching dual Q network.
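
For illustration, a minimal sketch (not part of the patch) of how the new `batch_size=None` option could be used with `SamplingConfig`. Only `batch_size` and `num_train_envs` appear in the diff above; constructing the config with just these keyword arguments, relying on defaults for everything else, is an assumption about the class.

```python
from tianshou.highlevel.config import SamplingConfig

# Minimal sketch, assuming SamplingConfig can be built from keyword arguments
# with defaults for all unspecified fields.

# On-policy / offline: batch_size=None performs each gradient update on the
# full collected buffer instead of on mini-batches.
on_policy_config = SamplingConfig(batch_size=None, num_train_envs=10)

# Off-policy: an explicit mini-batch size remains the recommended setting.
off_policy_config = SamplingConfig(batch_size=64, num_train_envs=10)
```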