diff --git a/d3rlpy/optimizers/optimizers.py b/d3rlpy/optimizers/optimizers.py
index 2b965251..0852a467 100644
--- a/d3rlpy/optimizers/optimizers.py
+++ b/d3rlpy/optimizers/optimizers.py
@@ -5,6 +5,7 @@
 from torch.optim import SGD, Adam, AdamW, Optimizer, RMSprop
 from torch.optim.lr_scheduler import LRScheduler
 
+from ..logging import LOG
 from ..serializable_config import DynamicConfig, generate_config_registration
 from .lr_schedulers import LRSchedulerFactory, make_lr_scheduler_field
 
@@ -102,9 +103,15 @@ def state_dict(self) -> Mapping[str, Any]:
         }
 
     def load_state_dict(self, state_dict: Mapping[str, Any]) -> None:
-        self._optim.load_state_dict(state_dict["optim"])
+        if "optim" in state_dict:
+            self._optim.load_state_dict(state_dict["optim"])
+        else:
+            LOG.warning("Skip loading optimizer state.")
         if self._lr_scheduler:
-            self._lr_scheduler.load_state_dict(state_dict["lr_scheduler"])
+            if "lr_scheduler" in state_dict:
+                self._lr_scheduler.load_state_dict(state_dict["lr_scheduler"])
+            else:
+                LOG.warning("Skip loading lr scheduler state.")
 
 
 @dataclasses.dataclass()
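
For context, a minimal standalone sketch of the tolerant loading behavior this patch introduces. It mirrors the logic above using plain PyTorch objects rather than d3rlpy's wrapper class; the helper name load_wrapper_state and the module-level logger are illustrative assumptions, not part of the change. The point it demonstrates: a checkpoint missing the "optim" or "lr_scheduler" key is now skipped with a warning instead of raising KeyError.

# Illustrative sketch only (hypothetical helper, not d3rlpy API): mirrors the
# tolerant loading logic added in the diff above using plain PyTorch objects.
import logging
from typing import Any, Mapping, Optional

from torch import nn
from torch.optim import SGD, Optimizer
from torch.optim.lr_scheduler import LRScheduler

LOG = logging.getLogger(__name__)


def load_wrapper_state(
    optim: Optimizer,
    lr_scheduler: Optional[LRScheduler],
    state_dict: Mapping[str, Any],
) -> None:
    # Restore the optimizer only if the checkpoint actually carries its state.
    if "optim" in state_dict:
        optim.load_state_dict(state_dict["optim"])
    else:
        LOG.warning("Skip loading optimizer state.")
    # Same guard for the optional learning-rate scheduler.
    if lr_scheduler:
        if "lr_scheduler" in state_dict:
            lr_scheduler.load_state_dict(state_dict["lr_scheduler"])
        else:
            LOG.warning("Skip loading lr scheduler state.")


# Usage: an older checkpoint that saved no optimizer state now loads with a
# warning instead of raising KeyError("optim").
model = nn.Linear(4, 2)
optim = SGD(model.parameters(), lr=0.1)
load_wrapper_state(optim, None, {})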