diff --git a/d3rlpy/algos/qlearning/base.py b/d3rlpy/algos/qlearning/base.py
index 51e38161..daea46db 100644
--- a/d3rlpy/algos/qlearning/base.py
+++ b/d3rlpy/algos/qlearning/base.py
@@ -378,6 +378,7 @@ def fit(
         experiment_name: Optional[str] = None,
         with_timestamp: bool = True,
         logging_steps: int = 500,
+        gradient_logging_steps: Optional[int] = None,
         logging_strategy: LoggingStrategy = LoggingStrategy.EPOCH,
         logger_adapter: LoggerAdapterFactory = FileAdapterFactory(),
         show_progress: bool = True,
@@ -403,6 +404,7 @@ def fit(
                 directory name.
            logging_steps: Number of steps to log metrics. This will be ignored
                if logging_strategy is EPOCH.
+           gradient_logging_steps: Number of steps to log gradients.
            logging_strategy: Logging strategy to use.
            logger_adapter: LoggerAdapterFactory object.
            show_progress: Flag to show progress bar for iterations.
@@ -425,6 +427,7 @@ def fit(
                 experiment_name=experiment_name,
                 with_timestamp=with_timestamp,
                 logging_steps=logging_steps,
+                gradient_logging_steps=gradient_logging_steps,
                 logging_strategy=logging_strategy,
                 logger_adapter=logger_adapter,
                 show_progress=show_progress,
@@ -442,6 +445,7 @@ def fitter(
         n_steps: int,
         n_steps_per_epoch: int = 10000,
         logging_steps: int = 500,
+        gradient_logging_steps: Optional[int] = None,
         logging_strategy: LoggingStrategy = LoggingStrategy.EPOCH,
         experiment_name: Optional[str] = None,
         with_timestamp: bool = True,
@@ -471,7 +475,8 @@ def fitter(
            with_timestamp: Flag to add timestamp string to the last of
                directory name.
            logging_steps: Number of steps to log metrics. This will be ignored
-               if loggig_strategy is EPOCH.
+               if logging_strategy is EPOCH.
+           gradient_logging_steps: Number of steps to log gradients.
            logging_strategy: Logging strategy to use.
            logger_adapter: LoggerAdapterFactory object.
            show_progress: Flag to show progress bar for iterations.
@@ -520,6 +525,8 @@ def fitter(
         # save hyperparameters
         save_config(self, logger)
 
+        logger.watch_model(0, 0, gradient_logging_steps, self)  # type: ignore
+
         # training loop
         n_epochs = n_steps // n_steps_per_epoch
         total_step = 0
@@ -559,6 +566,8 @@ def fitter(
 
                 total_step += 1
 
+                logger.watch_model(epoch, total_step, gradient_logging_steps, self)  # type: ignore
+
                 if (
                     logging_strategy == LoggingStrategy.STEPS
                     and total_step % logging_steps == 0
@@ -608,6 +617,7 @@ def fit_online(
         experiment_name: Optional[str] = None,
         with_timestamp: bool = True,
         logging_steps: int = 500,
+        gradient_logging_steps: Optional[int] = None,
         logging_strategy: LoggingStrategy = LoggingStrategy.EPOCH,
         logger_adapter: LoggerAdapterFactory = FileAdapterFactory(),
         show_progress: bool = True,
@@ -636,6 +646,7 @@ def fit_online(
                 directory name.
            logging_steps: Number of steps to log metrics. This will be ignored
                if logging_strategy is EPOCH.
+           gradient_logging_steps: Number of steps to log gradients.
            logging_strategy: Logging strategy to use.
            logger_adapter: LoggerAdapterFactory object.
            show_progress: Flag to show progress bar for iterations.
@@ -673,6 +684,8 @@ def fit_online(
         # save hyperparameters
         save_config(self, logger)
 
+        logger.watch_model(0, 0, gradient_logging_steps, self)  # type: ignore
+
         # switch based on show_progress flag
         xrange = trange if show_progress else range
 
@@ -741,6 +754,8 @@ def fit_online(
                     for name, val in loss.items():
                         logger.add_metric(name, val)
 
+            logger.watch_model(epoch, total_step, gradient_logging_steps, self)  # type: ignore
+
             if (
                 logging_strategy == LoggingStrategy.STEPS
                 and total_step % logging_steps == 0
diff --git a/d3rlpy/logging/file_adapter.py b/d3rlpy/logging/file_adapter.py
index 21439294..fbcf19a8 100644
--- a/d3rlpy/logging/file_adapter.py
+++ b/d3rlpy/logging/file_adapter.py
@@ -1,11 +1,17 @@
 import json
 import os
 from enum import Enum, IntEnum
-from typing import Any, Dict
+from typing import Any, Dict, Optional
 
 import numpy as np
 
-from .logger import LOG, LoggerAdapter, LoggerAdapterFactory, SaveProtocol
+from .logger import (
+    LOG,
+    LoggerAdapter,
+    LoggerAdapterFactory,
+    SaveProtocol,
+    TorchModuleProtocol,
+)
 
 __all__ = ["FileAdapter", "FileAdapterFactory"]
 
@@ -76,6 +82,25 @@ def close(self) -> None:
     def logdir(self) -> str:
         return self._logdir
 
+    def watch_model(
+        self,
+        epoch: int,
+        step: int,
+        logging_steps: Optional[int],
+        algo: TorchModuleProtocol,
+    ) -> None:
+        if logging_steps is not None and step % logging_steps == 0:
+            for name, grad in algo.impl.modules.get_gradients():
+                path = os.path.join(self._logdir, f"{name}_grad.csv")
+                with open(path, "a") as f:
+                    min_grad = grad.min()
+                    max_grad = grad.max()
+                    mean_grad = grad.mean()
+                    print(
+                        f"{epoch},{step},{name},{min_grad},{max_grad},{mean_grad}",
+                        file=f,
+                    )
+
 
 class FileAdapterFactory(LoggerAdapterFactory):
     r"""FileAdapterFactory class.
diff --git a/d3rlpy/logging/logger.py b/d3rlpy/logging/logger.py
index 48b8cd0d..ba289831 100644
--- a/d3rlpy/logging/logger.py
+++ b/d3rlpy/logging/logger.py
@@ -2,11 +2,14 @@
 from collections import defaultdict
 from contextlib import contextmanager
 from datetime import datetime
-from typing import Any, DefaultDict, Dict, Iterator, List
+from typing import Any, DefaultDict, Dict, Iterator, List, Optional, Tuple
 
 import structlog
+from torch import nn
 from typing_extensions import Protocol
 
+from ..types import Float32NDArray
+
 __all__ = [
     "LOG",
     "set_log_context",
@@ -39,6 +42,19 @@ class SaveProtocol(Protocol):
     def save(self, fname: str) -> None: ...
 
 
+class ModuleProtocol(Protocol):
+    def get_torch_modules(self) -> Dict[str, nn.Module]: ...
+    def get_gradients(self) -> Iterator[Tuple[str, Float32NDArray]]: ...
+
+
+class ImplProtocol(Protocol):
+    modules: ModuleProtocol
+
+
+class TorchModuleProtocol(Protocol):
+    impl: ImplProtocol
+
+
 class LoggerAdapter(Protocol):
     r"""Interface of LoggerAdapter."""
 
@@ -88,6 +104,22 @@ def save_model(self, epoch: int, algo: SaveProtocol) -> None:
     def close(self) -> None:
         r"""Closes this LoggerAdapter."""
 
+    def watch_model(
+        self,
+        epoch: int,
+        step: int,
+        logging_steps: Optional[int],
+        algo: TorchModuleProtocol,
+    ) -> None:
+        r"""Watch model parameters / gradients during training.
+
+        Args:
+            epoch: Epoch.
+            step: Training step.
+            logging_steps: Number of steps to log gradients.
+            algo: Algorithm.
+ """ + class LoggerAdapterFactory(Protocol): r"""Interface of LoggerAdapterFactory.""" @@ -171,3 +203,12 @@ def measure_time(self, name: str) -> Iterator[None]: @property def adapter(self) -> LoggerAdapter: return self._adapter + + def watch_model( + self, + epoch: int, + step: int, + logging_steps: Optional[int], + algo: TorchModuleProtocol, + ) -> None: + self._adapter.watch_model(epoch, step, logging_steps, algo) diff --git a/d3rlpy/logging/noop_adapter.py b/d3rlpy/logging/noop_adapter.py index 574ba3a9..aa4ece67 100644 --- a/d3rlpy/logging/noop_adapter.py +++ b/d3rlpy/logging/noop_adapter.py @@ -1,6 +1,11 @@ -from typing import Any, Dict +from typing import Any, Dict, Optional -from .logger import LoggerAdapter, LoggerAdapterFactory, SaveProtocol +from .logger import ( + LoggerAdapter, + LoggerAdapterFactory, + SaveProtocol, + TorchModuleProtocol, +) __all__ = ["NoopAdapter", "NoopAdapterFactory"] @@ -32,6 +37,15 @@ def save_model(self, epoch: int, algo: SaveProtocol) -> None: def close(self) -> None: pass + def watch_model( + self, + epoch: int, + step: int, + logging_steps: Optional[int], + algo: TorchModuleProtocol, + ) -> None: + pass + class NoopAdapterFactory(LoggerAdapterFactory): r"""NoopAdapterFactory class. diff --git a/d3rlpy/logging/tensorboard_adapter.py b/d3rlpy/logging/tensorboard_adapter.py index 5b7e623a..28ec646f 100644 --- a/d3rlpy/logging/tensorboard_adapter.py +++ b/d3rlpy/logging/tensorboard_adapter.py @@ -1,9 +1,14 @@ import os -from typing import Any, Dict +from typing import Any, Dict, Optional import numpy as np -from .logger import LoggerAdapter, LoggerAdapterFactory, SaveProtocol +from .logger import ( + LoggerAdapter, + LoggerAdapterFactory, + SaveProtocol, + TorchModuleProtocol, +) __all__ = ["TensorboardAdapter", "TensorboardAdapterFactory"] @@ -64,6 +69,19 @@ def save_model(self, epoch: int, algo: SaveProtocol) -> None: def close(self) -> None: self._writer.close() + def watch_model( + self, + epoch: int, + step: int, + logging_steps: Optional[int], + algo: TorchModuleProtocol, + ) -> None: + if logging_steps is not None and step % logging_steps == 0: + for name, grad in algo.impl.modules.get_gradients(): + self._writer.add_histogram( + f"histograms/{name}_grad", grad, epoch + ) + class TensorboardAdapterFactory(LoggerAdapterFactory): r"""TensorboardAdapterFactory class. diff --git a/d3rlpy/logging/utils.py b/d3rlpy/logging/utils.py index 48bb77e1..467d1e33 100644 --- a/d3rlpy/logging/utils.py +++ b/d3rlpy/logging/utils.py @@ -1,6 +1,11 @@ -from typing import Any, Dict, Sequence +from typing import Any, Dict, Optional, Sequence -from .logger import LoggerAdapter, LoggerAdapterFactory, SaveProtocol +from .logger import ( + LoggerAdapter, + LoggerAdapterFactory, + SaveProtocol, + TorchModuleProtocol, +) __all__ = ["CombineAdapter", "CombineAdapterFactory"] @@ -44,6 +49,16 @@ def close(self) -> None: for adapter in self._adapters: adapter.close() + def watch_model( + self, + epoch: int, + step: int, + logging_steps: Optional[int], + algo: TorchModuleProtocol, + ) -> None: + for adapter in self._adapters: + adapter.watch_model(epoch, step, logging_steps, algo) + class CombineAdapterFactory(LoggerAdapterFactory): r"""CombineAdapterFactory class. 
diff --git a/d3rlpy/logging/wandb_adapter.py b/d3rlpy/logging/wandb_adapter.py
index 616b6850..08813eb2 100644
--- a/d3rlpy/logging/wandb_adapter.py
+++ b/d3rlpy/logging/wandb_adapter.py
@@ -1,6 +1,11 @@
 from typing import Any, Dict, Optional
 
-from .logger import LoggerAdapter, LoggerAdapterFactory, SaveProtocol
+from .logger import (
+    LoggerAdapter,
+    LoggerAdapterFactory,
+    SaveProtocol,
+    TorchModuleProtocol,
+)
 
 __all__ = ["WanDBAdapter", "WanDBAdapterFactory"]
 
@@ -24,6 +29,7 @@ def __init__(
         except ImportError as e:
             raise ImportError("Please install wandb") from e
         self.run = wandb.init(project=project, name=experiment_name)
+        self._is_model_watched = False
 
     def write_params(self, params: Dict[str, Any]) -> None:
         """Writes hyperparameters to WandB config."""
@@ -52,6 +58,21 @@ def close(self) -> None:
         """Closes the logger and finishes the WandB run."""
         self.run.finish()
 
+    def watch_model(
+        self,
+        epoch: int,
+        step: int,
+        logging_steps: Optional[int],
+        algo: TorchModuleProtocol,
+    ) -> None:
+        if not self._is_model_watched:
+            self.run.watch(
+                tuple(algo.impl.modules.get_torch_modules().values()),
+                log="gradients",
+                log_freq=logging_steps,
+            )
+            self._is_model_watched = True
+
 
 class WanDBAdapterFactory(LoggerAdapterFactory):
     r"""WandB Logger Adapter Factory class.
diff --git a/d3rlpy/torch_utility.py b/d3rlpy/torch_utility.py
index 2b231784..ab47d3ea 100644
--- a/d3rlpy/torch_utility.py
+++ b/d3rlpy/torch_utility.py
@@ -4,8 +4,10 @@
     Any,
     BinaryIO,
     Dict,
+    Iterator,
     Optional,
     Sequence,
+    Tuple,
     TypeVar,
     Union,
     overload,
@@ -388,6 +390,19 @@ def reset_optimizer_states(self) -> None:
            if isinstance(v, torch.optim.Optimizer):
                v.state = collections.defaultdict(dict)
 
+    def get_torch_modules(self) -> Dict[str, nn.Module]:
+        torch_modules: Dict[str, nn.Module] = {}
+        for k, v in asdict_without_copy(self).items():
+            if isinstance(v, nn.Module):
+                torch_modules[k] = v
+        return torch_modules
+
+    def get_gradients(self) -> Iterator[Tuple[str, Float32NDArray]]:
+        for module_name, module in self.get_torch_modules().items():
+            for name, parameter in module.named_parameters():
+                if parameter.requires_grad and parameter.grad is not None:
+                    yield f"{module_name}.{name}", parameter.grad.cpu().detach().numpy()
+
 
 TCallable = TypeVar("TCallable")
 
diff --git a/tests/logging/test_logger.py b/tests/logging/test_logger.py
index 55ca4443..c9584682 100644
--- a/tests/logging/test_logger.py
+++ b/tests/logging/test_logger.py
@@ -1,9 +1,10 @@
-from typing import Any, Dict
+from typing import Any, Dict, List
 
 import pytest
+from torch import nn
 
 from d3rlpy.logging import D3RLPyLogger
-from d3rlpy.logging.logger import SaveProtocol
+from d3rlpy.logging.logger import SaveProtocol, TorchModuleProtocol
 
 
 class StubLoggerAdapter:
@@ -15,6 +16,7 @@ def __init__(self, experiment_name: str):
         self.is_after_write_metric_called = False
         self.is_save_model_called = False
         self.is_close_called = False
+        self.is_watch_model_called = False
 
     def write_params(self, params: Dict[str, Any]) -> None:
         self.is_write_params_called = True
@@ -39,20 +41,40 @@ def save_model(self, epoch: int, algo: SaveProtocol) -> None:
     def close(self) -> None:
         self.is_close_called = True
 
+    def watch_model(
+        self,
+        epoch: int,
+        step: int,
+        logging_step: int,
+        algo: TorchModuleProtocol,
+    ) -> None:
+        self.is_watch_model_called = True
+
 
 class StubLoggerAdapterFactory:
     def create(self, experiment_name: str) -> StubLoggerAdapter:
         return StubLoggerAdapter(experiment_name)
 
 
+class StubModules:
+    def get_torch_modules(self) -> List[nn.Module]:
+        return []
+
+
+class StubImpl:
+    modules: StubModules
+
+
 class StubAlgo:
+    impl: StubImpl
+
     def save(self, fname: str) -> None:
         pass
 
 
 @pytest.mark.parametrize("with_timestamp", [False, True])
 def test_d3rlpy_logger(with_timestamp: bool) -> None:
-    logger = D3RLPyLogger(StubLoggerAdapterFactory(), "test", with_timestamp)
+    logger = D3RLPyLogger(StubLoggerAdapterFactory(), "test", with_timestamp)  # type: ignore
 
     # check experiment_name
     adapter = logger.adapter
@@ -87,3 +109,7 @@ def test_d3rlpy_logger(with_timestamp: bool) -> None:
     assert not adapter.is_close_called
     logger.close()
     assert adapter.is_close_called
+
+    assert not adapter.is_watch_model_called
+    logger.watch_model(1, 1, 1, StubAlgo())
+    assert adapter.is_watch_model_called
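End-to-end usage, as a sketch rather than part of the patch: get_cartpole and DQNConfig are existing d3rlpy helpers used purely for illustration, and the step counts are arbitrary. With FileAdapterFactory, each watched parameter produces a <name>_grad.csv file under the run directory containing epoch,step,name,min,max,mean rows; with TensorboardAdapterFactory the same gradients appear under histograms/<name>_grad. Leaving gradient_logging_steps at its default of None keeps the file and TensorBoard adapters from writing any gradient output.

import d3rlpy
from d3rlpy.logging import FileAdapterFactory

# Illustrative dataset/algorithm; any Q-learning algorithm works because
# gradients are collected generically via Modules.get_gradients().
dataset, env = d3rlpy.datasets.get_cartpole()
dqn = d3rlpy.algos.DQNConfig().create()

dqn.fit(
    dataset,
    n_steps=10000,
    n_steps_per_epoch=1000,
    logger_adapter=FileAdapterFactory(),
    # Log gradient statistics every 100 training steps.
    gradient_logging_steps=100,
)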