Skip to content

Commit

Permalink
more updates
Browse files Browse the repository at this point in the history
  • Loading branch information
hasan-yaman committed Oct 12, 2024
1 parent a2e4431 commit 502d526
Show file tree
Hide file tree
Showing 2 changed files with 22 additions and 7 deletions.
18 changes: 11 additions & 7 deletions d3rlpy/logging/logger.py
Original file line number Diff line number Diff line change
Expand Up @@ -90,12 +90,11 @@ def write_histogram(
) -> None:
r"""Writes histogram.
# TODO
Args:
epoch:
step:
name:
values:
epoch: Epoch.
step: Training step.
name: Histogram name.
values: Histogram values.
"""

def after_write_metric(self, epoch: int, step: int) -> None:
Expand All @@ -120,7 +119,12 @@ def close(self) -> None:
def watch_model(
self, logging_steps: int, algo: TorchModuleProtocol
) -> None:
r"""TODO: Docstring and type"""
r"""Watch model parameters / gradients during training.
Args:
logging_steps: Training step.
algo: Algorithm.
"""


class LoggerAdapterFactory(Protocol):
Expand Down Expand Up @@ -214,5 +218,5 @@ def measure_time(self, name: str) -> Iterator[None]:
def adapter(self) -> LoggerAdapter:
return self._adapter

def watch_model(self, logging_steps, algo: TorchModuleProtocol) -> None:
def watch_model(self, logging_steps: int, algo: TorchModuleProtocol) -> None:
self._adapter.watch_model(logging_steps, algo)
11 changes: 11 additions & 0 deletions tests/logging/test_logger.py
Original file line number Diff line number Diff line change
@@ -1,9 +1,11 @@
from typing import Any, Dict

import numpy as np
import pytest

from d3rlpy.logging import D3RLPyLogger
from d3rlpy.logging.logger import SaveProtocol
from d3rlpy.types import Float32NDArray


class StubLoggerAdapter:
Expand All @@ -12,6 +14,7 @@ def __init__(self, experiment_name: str):
self.is_write_params_called = False
self.is_before_write_metric_called = False
self.is_write_metric_called = False
self.is_write_histogram_called = False
self.is_after_write_metric_called = False
self.is_save_model_called = False
self.is_close_called = False
Expand All @@ -28,6 +31,11 @@ def write_metric(
assert self.is_before_write_metric_called
self.is_write_metric_called = True

def write_histogram(
self, epoch: int, step: int, name: str, values: Float32NDArray
) -> None:
self.is_write_histogram_called = True

def after_write_metric(self, epoch: int, step: int) -> None:
assert self.is_before_write_metric_called
assert self.is_write_metric_called
Expand Down Expand Up @@ -67,17 +75,20 @@ def test_d3rlpy_logger(with_timestamp: bool) -> None:
assert adapter.is_write_params_called

logger.add_metric("test", 1)
logger.add_histogram("test", np.array([1.0], dtype=np.float32))
with logger.measure_time("test"):
pass

assert not adapter.is_before_write_metric_called
assert not adapter.is_write_metric_called
assert not adapter.is_write_histogram_called
assert not adapter.is_after_write_metric_called
metrics = logger.commit(1, 1)
assert "test" in metrics
assert "time_test" in metrics
assert adapter.is_before_write_metric_called
assert adapter.is_write_metric_called
assert adapter.is_write_histogram_called
assert adapter.is_after_write_metric_called

assert not adapter.is_save_model_called
Expand Down

0 comments on commit 502d526

Please sign in to comment.