diff --git a/d3rlpy/logging/logger.py b/d3rlpy/logging/logger.py index 61ccdad5..3792c947 100644 --- a/d3rlpy/logging/logger.py +++ b/d3rlpy/logging/logger.py @@ -90,12 +90,11 @@ def write_histogram( ) -> None: r"""Writes histogram. - # TODO Args: - epoch: - step: - name: - values: + epoch: Epoch. + step: Training step. + name: Histogram name. + values: Histogram values. """ def after_write_metric(self, epoch: int, step: int) -> None: @@ -120,7 +119,12 @@ def close(self) -> None: def watch_model( self, logging_steps: int, algo: TorchModuleProtocol ) -> None: - r"""TODO: Docstring and type""" + r"""Watch model parameters / gradients during training. + + Args: + logging_steps: Training step. + algo: Algorithm. + """ class LoggerAdapterFactory(Protocol): @@ -214,5 +218,5 @@ def measure_time(self, name: str) -> Iterator[None]: def adapter(self) -> LoggerAdapter: return self._adapter - def watch_model(self, logging_steps, algo: TorchModuleProtocol) -> None: + def watch_model(self, logging_steps: int, algo: TorchModuleProtocol) -> None: self._adapter.watch_model(logging_steps, algo) diff --git a/tests/logging/test_logger.py b/tests/logging/test_logger.py index 55ca4443..ec4dd99e 100644 --- a/tests/logging/test_logger.py +++ b/tests/logging/test_logger.py @@ -1,9 +1,11 @@ from typing import Any, Dict +import numpy as np import pytest from d3rlpy.logging import D3RLPyLogger from d3rlpy.logging.logger import SaveProtocol +from d3rlpy.types import Float32NDArray class StubLoggerAdapter: @@ -12,6 +14,7 @@ def __init__(self, experiment_name: str): self.is_write_params_called = False self.is_before_write_metric_called = False self.is_write_metric_called = False + self.is_write_histogram_called = False self.is_after_write_metric_called = False self.is_save_model_called = False self.is_close_called = False @@ -28,6 +31,11 @@ def write_metric( assert self.is_before_write_metric_called self.is_write_metric_called = True + def write_histogram( + self, epoch: int, step: int, name: str, values: Float32NDArray + ) -> None: + self.is_write_histogram_called = True + def after_write_metric(self, epoch: int, step: int) -> None: assert self.is_before_write_metric_called assert self.is_write_metric_called @@ -67,17 +75,20 @@ def test_d3rlpy_logger(with_timestamp: bool) -> None: assert adapter.is_write_params_called logger.add_metric("test", 1) + logger.add_histogram("test", np.array([1.0], dtype=np.float32)) with logger.measure_time("test"): pass assert not adapter.is_before_write_metric_called assert not adapter.is_write_metric_called + assert not adapter.is_write_histogram_called assert not adapter.is_after_write_metric_called metrics = logger.commit(1, 1) assert "test" in metrics assert "time_test" in metrics assert adapter.is_before_write_metric_called assert adapter.is_write_metric_called + assert adapter.is_write_histogram_called assert adapter.is_after_write_metric_called assert not adapter.is_save_model_called