diff --git a/d3rlpy/algos/qlearning/base.py b/d3rlpy/algos/qlearning/base.py
index daea46db..e2517961 100644
--- a/d3rlpy/algos/qlearning/base.py
+++ b/d3rlpy/algos/qlearning/base.py
@@ -525,7 +525,7 @@ def fitter(
         # save hyperparameters
         save_config(self, logger)
 
-        logger.watch_model(0, 0, gradient_logging_steps, self)  # type: ignore
+        logger.watch_model(0, 0, gradient_logging_steps, self)
 
         # training loop
         n_epochs = n_steps // n_steps_per_epoch
@@ -566,7 +566,9 @@ def fitter(
 
                 total_step += 1
 
-                logger.watch_model(epoch, total_step, gradient_logging_steps, self)  # type: ignore
+                logger.watch_model(
+                    epoch, total_step, gradient_logging_steps, self
+                )
 
                 if (
                     logging_strategy == LoggingStrategy.STEPS
@@ -684,7 +686,7 @@ def fit_online(
         # save hyperparameters
         save_config(self, logger)
 
-        logger.watch_model(0, 0, gradient_logging_steps, self)  # type: ignore
+        logger.watch_model(0, 0, gradient_logging_steps, self)
 
         # switch based on show_progress flag
         xrange = trange if show_progress else range
@@ -754,7 +756,9 @@ def fit_online(
                     for name, val in loss.items():
                         logger.add_metric(name, val)
 
-            logger.watch_model(epoch, total_step, gradient_logging_steps, self)  # type: ignore
+            logger.watch_model(
+                epoch, total_step, gradient_logging_steps, self
+            )
 
             if (
                 logging_strategy == LoggingStrategy.STEPS
diff --git a/d3rlpy/logging/file_adapter.py b/d3rlpy/logging/file_adapter.py
index fbcf19a8..cc87d8a4 100644
--- a/d3rlpy/logging/file_adapter.py
+++ b/d3rlpy/logging/file_adapter.py
@@ -7,10 +7,10 @@
 from .logger import (
     LOG,
+    AlgProtocol,
     LoggerAdapter,
     LoggerAdapterFactory,
     SaveProtocol,
-    TorchModuleProtocol,
 )
 
 __all__ = ["FileAdapter", "FileAdapterFactory"]
 
@@ -87,8 +87,9 @@ def watch_model(
         epoch: int,
         step: int,
         logging_steps: Optional[int],
-        algo: TorchModuleProtocol,
+        algo: AlgProtocol,
     ) -> None:
+        assert algo.impl
         if logging_steps is not None and step % logging_steps == 0:
             for name, grad in algo.impl.modules.get_gradients():
                 path = os.path.join(self._logdir, f"{name}_grad.csv")
diff --git a/d3rlpy/logging/logger.py b/d3rlpy/logging/logger.py
index ba289831..a0cb152f 100644
--- a/d3rlpy/logging/logger.py
+++ b/d3rlpy/logging/logger.py
@@ -48,11 +48,15 @@ def get_gradients(self) -> Iterator[Tuple[str, Float32NDArray]]: ...
 
 
 class ImplProtocol(Protocol):
-    modules: ModuleProtocol
+    @property
+    def modules(self) -> ModuleProtocol:
+        raise NotImplementedError
 
 
-class TorchModuleProtocol(Protocol):
-    impl: ImplProtocol
+class AlgProtocol(Protocol):
+    @property
+    def impl(self) -> Optional[ImplProtocol]:
+        raise NotImplementedError
 
 
 class LoggerAdapter(Protocol):
@@ -109,7 +113,7 @@ def watch_model(
         epoch: int,
         step: int,
         logging_steps: Optional[int],
-        algo: TorchModuleProtocol,
+        algo: AlgProtocol,
     ) -> None:
         r"""Watch model parameters / gradients during training.
 
@@ -209,6 +213,6 @@ def watch_model(
         epoch: int,
         step: int,
         logging_steps: Optional[int],
-        algo: TorchModuleProtocol,
+        algo: AlgProtocol,
     ) -> None:
         self._adapter.watch_model(epoch, step, logging_steps, algo)
diff --git a/d3rlpy/logging/noop_adapter.py b/d3rlpy/logging/noop_adapter.py
index aa4ece67..de01eec0 100644
--- a/d3rlpy/logging/noop_adapter.py
+++ b/d3rlpy/logging/noop_adapter.py
@@ -1,10 +1,10 @@
 from typing import Any, Dict, Optional
 
 from .logger import (
+    AlgProtocol,
     LoggerAdapter,
     LoggerAdapterFactory,
     SaveProtocol,
-    TorchModuleProtocol,
 )
 
 __all__ = ["NoopAdapter", "NoopAdapterFactory"]
@@ -42,7 +42,7 @@ def watch_model(
         epoch: int,
         step: int,
         logging_steps: Optional[int],
-        algo: TorchModuleProtocol,
+        algo: AlgProtocol,
     ) -> None:
         pass
 
diff --git a/d3rlpy/logging/tensorboard_adapter.py b/d3rlpy/logging/tensorboard_adapter.py
index 28ec646f..ba248785 100644
--- a/d3rlpy/logging/tensorboard_adapter.py
+++ b/d3rlpy/logging/tensorboard_adapter.py
@@ -4,10 +4,10 @@
 import numpy as np
 
 from .logger import (
+    AlgProtocol,
     LoggerAdapter,
     LoggerAdapterFactory,
     SaveProtocol,
-    TorchModuleProtocol,
 )
 
 __all__ = ["TensorboardAdapter", "TensorboardAdapterFactory"]
@@ -74,8 +74,9 @@ def watch_model(
         epoch: int,
         step: int,
         logging_steps: Optional[int],
-        algo: TorchModuleProtocol,
+        algo: AlgProtocol,
     ) -> None:
+        assert algo.impl
         if logging_steps is not None and step % logging_steps == 0:
             for name, grad in algo.impl.modules.get_gradients():
                 self._writer.add_histogram(
diff --git a/d3rlpy/logging/utils.py b/d3rlpy/logging/utils.py
index 467d1e33..81412395 100644
--- a/d3rlpy/logging/utils.py
+++ b/d3rlpy/logging/utils.py
@@ -1,10 +1,10 @@
 from typing import Any, Dict, Optional, Sequence
 
 from .logger import (
+    AlgProtocol,
     LoggerAdapter,
     LoggerAdapterFactory,
     SaveProtocol,
-    TorchModuleProtocol,
 )
 
 __all__ = ["CombineAdapter", "CombineAdapterFactory"]
@@ -54,7 +54,7 @@ def watch_model(
         epoch: int,
         step: int,
         logging_steps: Optional[int],
-        algo: TorchModuleProtocol,
+        algo: AlgProtocol,
     ) -> None:
         for adapter in self._adapters:
             adapter.watch_model(epoch, step, logging_steps, algo)
diff --git a/d3rlpy/logging/wandb_adapter.py b/d3rlpy/logging/wandb_adapter.py
index 08813eb2..4e388a26 100644
--- a/d3rlpy/logging/wandb_adapter.py
+++ b/d3rlpy/logging/wandb_adapter.py
@@ -1,10 +1,10 @@
 from typing import Any, Dict, Optional
 
 from .logger import (
+    AlgProtocol,
     LoggerAdapter,
     LoggerAdapterFactory,
     SaveProtocol,
-    TorchModuleProtocol,
 )
 
 __all__ = ["WanDBAdapter", "WanDBAdapterFactory"]
@@ -63,9 +63,10 @@ def watch_model(
         epoch: int,
         step: int,
         logging_steps: Optional[int],
-        algo: TorchModuleProtocol,
+        algo: AlgProtocol,
     ) -> None:
         if not self._is_model_watched:
+            assert algo.impl
             self.run.watch(
                 tuple(algo.impl.modules.get_torch_modules().values()),
                 log="gradients",
diff --git a/tests/logging/test_logger.py b/tests/logging/test_logger.py
index c9584682..9b3ec0f0 100644
--- a/tests/logging/test_logger.py
+++ b/tests/logging/test_logger.py
@@ -4,7 +4,7 @@
 from torch import nn
 
 from d3rlpy.logging import D3RLPyLogger
-from d3rlpy.logging.logger import SaveProtocol, TorchModuleProtocol
+from d3rlpy.logging.logger import AlgProtocol, SaveProtocol
 
 
 class StubLoggerAdapter:
@@ -46,7 +46,7 @@ def watch_model(
         epoch: int,
         step: int,
         logging_step: int,
-        algo: TorchModuleProtocol,
+        algo: AlgProtocol,
     ) -> None:
         self.is_watch_model_called = True
 
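The net effect of this change: logging adapters now receive the algorithm through `AlgProtocol`, whose `impl` property is typed `Optional[ImplProtocol]`. That matches the algorithm's lifecycle, where `impl` is `None` until the implementation is built, so the `# type: ignore` comments in `base.py` can be dropped and each concrete adapter instead guards the dereference with `assert algo.impl`. Declaring `modules` and `impl` as properties rather than plain protocol attributes also lets read-only properties on the implementing classes satisfy the protocols under mypy's structural checks. Below is a minimal sketch of the same calling pattern for custom logging code; `log_gradient_norms` is a hypothetical helper written against the protocols in this diff, not part of the library.

```python
from typing import Optional

from d3rlpy.logging.logger import AlgProtocol


def log_gradient_norms(
    step: int, logging_steps: Optional[int], algo: AlgProtocol
) -> None:
    # Under AlgProtocol, impl is Optional: it stays None until the algorithm
    # has been built, so bail out (or assert) before dereferencing it.
    if algo.impl is None:
        return
    if logging_steps is None or step % logging_steps != 0:
        return
    # get_gradients() yields (parameter name, gradient array) pairs, the same
    # iterator consumed by FileAdapter and TensorboardAdapter above.
    for name, grad in algo.impl.modules.get_gradients():
        l2_norm = float((grad ** 2).sum()) ** 0.5
        print(f"step={step} {name} grad_l2={l2_norm:.6f}")
```

Any object that structurally exposes an `impl` with `modules` satisfies `AlgProtocol`; in practice that is the algorithm itself, which `fitter` and `fit_online` in `d3rlpy/algos/qlearning/base.py` pass as `self`.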