diff --git a/d3rlpy/algos/qlearning/base.py b/d3rlpy/algos/qlearning/base.py
index b1d8791f..42529576 100644
--- a/d3rlpy/algos/qlearning/base.py
+++ b/d3rlpy/algos/qlearning/base.py
@@ -613,6 +613,7 @@ def fit_online(
         random_steps: int = 0,
         eval_env: Optional[GymEnv] = None,
         eval_epsilon: float = 0.0,
+        eval_n_trials: int = 10,
         save_interval: int = 1,
         experiment_name: Optional[str] = None,
         with_timestamp: bool = True,
@@ -764,7 +765,10 @@ def fit_online(
             # evaluation
             if eval_env:
                 eval_score = evaluate_qlearning_with_environment(
-                    self, eval_env, epsilon=eval_epsilon
+                    self,
+                    eval_env,
+                    n_trials=eval_n_trials,
+                    epsilon=eval_epsilon,
                 )
                 logger.add_metric("evaluation", eval_score)
 
diff --git a/d3rlpy/dataset/replay_buffer.py b/d3rlpy/dataset/replay_buffer.py
index dfe68262..0b70f5b6 100644
--- a/d3rlpy/dataset/replay_buffer.py
+++ b/d3rlpy/dataset/replay_buffer.py
@@ -751,6 +751,7 @@ def create_fifo_replay_buffer(
     trajectory_slicer: Optional[TrajectorySlicerProtocol] = None,
     writer_preprocessor: Optional[WriterPreprocessProtocol] = None,
     env: Optional[GymEnv] = None,
+    write_at_termination: bool = False,
 ) -> ReplayBuffer:
     """Builds FIFO replay buffer.
 
@@ -770,6 +771,8 @@ def create_fifo_replay_buffer(
             Writer preprocessor implementation. If ``None`` is given,
             ``BasicWriterPreprocess`` is used by default.
         env: Gym environment to extract shapes of observations and action.
+        write_at_termination (bool): Flag to write experiences to the buffer at the
+            end of an episode all at once.
 
     Returns:
         Replay buffer.
@@ -782,6 +785,7 @@ def create_fifo_replay_buffer(
         trajectory_slicer=trajectory_slicer,
         writer_preprocessor=writer_preprocessor,
         env=env,
+        write_at_termination=write_at_termination,
     )
 
 
@@ -791,6 +795,7 @@ def create_infinite_replay_buffer(
     trajectory_slicer: Optional[TrajectorySlicerProtocol] = None,
     writer_preprocessor: Optional[WriterPreprocessProtocol] = None,
     env: Optional[GymEnv] = None,
+    write_at_termination: bool = False,
 ) -> ReplayBuffer:
     """Builds infinite replay buffer.
 
@@ -809,6 +814,8 @@ def create_infinite_replay_buffer(
             Writer preprocessor implementation. If ``None`` is given,
             ``BasicWriterPreprocess`` is used by default.
         env: Gym environment to extract shapes of observations and action.
+        write_at_termination (bool): Flag to write experiences to the buffer at the
+            end of an episode all at once.
 
     Returns:
         Replay buffer.
@@ -821,4 +828,5 @@ def create_infinite_replay_buffer(
         trajectory_slicer=trajectory_slicer,
         writer_preprocessor=writer_preprocessor,
         env=env,
+        write_at_termination=write_at_termination,
     )
diff --git a/reproductions/finetuning/cal_ql_finetune.py b/reproductions/finetuning/cal_ql_finetune.py
index e7131235..35a492d2 100644
--- a/reproductions/finetuning/cal_ql_finetune.py
+++ b/reproductions/finetuning/cal_ql_finetune.py
@@ -59,7 +59,9 @@ def main() -> None:
         n_steps=1000000,
         n_steps_per_epoch=1000,
         save_interval=10,
-        evaluators={"environment": d3rlpy.metrics.EnvironmentEvaluator(env)},
+        evaluators={
+            "environment": d3rlpy.metrics.EnvironmentEvaluator(env, n_trials=20)
+        },
         experiment_name=f"CalQL_pretraining_{args.dataset}_{args.seed}",
     )
 
@@ -68,6 +70,7 @@ def main() -> None:
         limit=1000000,
         env=env,
         transition_picker=transition_picker,
+        write_at_termination=True,
     )
 
     # sample half from offline dataset and the rest from online buffer
@@ -90,6 +93,7 @@ def main() -> None:
         n_updates=1000,
         update_interval=1000,
         save_interval=10,
+        eval_n_trials=20,
     )
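
A minimal sketch of how the two new arguments compose outside the Cal-QL reproduction script. The CartPole environment, the DQN algorithm, and every hyperparameter below are illustrative assumptions; only create_fifo_replay_buffer, write_at_termination, fit_online, and eval_n_trials are taken from this patch.

import gym
import d3rlpy

# Environments for online training and periodic evaluation.
env = gym.make("CartPole-v1")
eval_env = gym.make("CartPole-v1")

# FIFO buffer that writes each episode to storage all at once when the
# episode terminates, instead of step by step.
buffer = d3rlpy.dataset.create_fifo_replay_buffer(
    limit=100000,
    env=env,
    write_at_termination=True,
)

dqn = d3rlpy.algos.DQNConfig().create()

# eval_n_trials averages the evaluation score over 20 episodes per epoch.
dqn.fit_online(
    env,
    buffer,
    n_steps=10000,
    n_steps_per_epoch=1000,
    eval_env=eval_env,
    eval_n_trials=20,
)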