
Commit 7951f3d

Resolve typing issue with gradient metrics

takuseno committed Oct 15, 2024
1 parent 3d51ee7
Showing 8 changed files with 32 additions and 21 deletions.
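This commit replaces `TorchModuleProtocol`, which typed `impl` as a plain, always-present attribute, with `AlgProtocol`, whose `impl` is a read-only property typed `Optional[ImplProtocol]`. Because the concrete algorithm classes expose `impl` as a read-only property that is `None` before training, they now satisfy the protocol structurally, and the `# type: ignore` suppressions on the `watch_model` calls can be dropped. A minimal sketch of the mismatch, using simplified stand-in names rather than the actual d3rlpy classes:

```python
from typing import Optional, Protocol


class Impl:
    ...


class OldProtocol(Protocol):
    impl: Impl  # plain attribute: must be settable and never None


class NewProtocol(Protocol):
    @property
    def impl(self) -> Optional[Impl]:  # read-only and possibly None
        raise NotImplementedError


class Algo:
    # Stand-in for an algorithm whose impl is created lazily.
    @property
    def impl(self) -> Optional[Impl]:
        return None


a: OldProtocol = Algo()  # mypy rejects this: read-only Optional property
b: NewProtocol = Algo()  # mypy accepts this: signatures match structurally
```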
12 changes: 8 additions & 4 deletions d3rlpy/algos/qlearning/base.py
@@ -525,7 +525,7 @@ def fitter(
# save hyperparameters
save_config(self, logger)

- logger.watch_model(0, 0, gradient_logging_steps, self)  # type: ignore
+ logger.watch_model(0, 0, gradient_logging_steps, self)

# training loop
n_epochs = n_steps // n_steps_per_epoch
@@ -566,7 +566,9 @@ def fitter(

total_step += 1

- logger.watch_model(epoch, total_step, gradient_logging_steps, self)  # type: ignore
+ logger.watch_model(
+     epoch, total_step, gradient_logging_steps, self
+ )

if (
logging_strategy == LoggingStrategy.STEPS
@@ -684,7 +686,7 @@ def fit_online(
# save hyperparameters
save_config(self, logger)

- logger.watch_model(0, 0, gradient_logging_steps, self)  # type: ignore
+ logger.watch_model(0, 0, gradient_logging_steps, self)

# switch based on show_progress flag
xrange = trange if show_progress else range
@@ -754,7 +756,9 @@ def fit_online(
for name, val in loss.items():
logger.add_metric(name, val)

- logger.watch_model(epoch, total_step, gradient_logging_steps, self)  # type: ignore
+ logger.watch_model(
+     epoch, total_step, gradient_logging_steps, self
+ )

if (
logging_strategy == LoggingStrategy.STEPS
5 changes: 3 additions & 2 deletions d3rlpy/logging/file_adapter.py
@@ -7,10 +7,10 @@

from .logger import (
    LOG,
+   AlgProtocol,
    LoggerAdapter,
    LoggerAdapterFactory,
    SaveProtocol,
-   TorchModuleProtocol,
)

__all__ = ["FileAdapter", "FileAdapterFactory"]
@@ -87,8 +87,9 @@ def watch_model(
epoch: int,
step: int,
logging_steps: Optional[int],
- algo: TorchModuleProtocol,
+ algo: AlgProtocol,
) -> None:
+ assert algo.impl
if logging_steps is not None and step % logging_steps == 0:
for name, grad in algo.impl.modules.get_gradients():
path = os.path.join(self._logdir, f"{name}_grad.csv")
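Because `impl` is `Optional` under `AlgProtocol` (it is only populated once fitting begins), each adapter now asserts it is set before dereferencing it; the `assert` also narrows the type for the checker. A minimal sketch of that narrowing pattern, again with stand-in classes rather than the real ones:

```python
from typing import Optional


class Impl:
    gradient_norm = 0.0


class Algo:
    impl: Optional[Impl] = None  # set once training starts


def watch(algo: Algo) -> float:
    assert algo.impl  # narrows Optional[Impl] to Impl below this line
    return algo.impl.gradient_norm
```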
14 changes: 9 additions & 5 deletions d3rlpy/logging/logger.py
@@ -48,11 +48,15 @@ def get_gradients(self) -> Iterator[Tuple[str, Float32NDArray]]: ...


class ImplProtocol(Protocol):
-     modules: ModuleProtocol
+     @property
+     def modules(self) -> ModuleProtocol:
+         raise NotImplementedError


- class TorchModuleProtocol(Protocol):
-     impl: ImplProtocol
+ class AlgProtocol(Protocol):
+     @property
+     def impl(self) -> Optional[ImplProtocol]:
+         raise NotImplementedError


class LoggerAdapter(Protocol):
@@ -109,7 +113,7 @@ def watch_model(
epoch: int,
step: int,
logging_steps: Optional[int],
- algo: TorchModuleProtocol,
+ algo: AlgProtocol,
) -> None:
r"""Watch model parameters / gradients during training.
@@ -209,6 +213,6 @@ def watch_model(
epoch: int,
step: int,
logging_steps: Optional[int],
- algo: TorchModuleProtocol,
+ algo: AlgProtocol,
) -> None:
self._adapter.watch_model(epoch, step, logging_steps, algo)
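Since `AlgProtocol` is a structural `Protocol`, a custom adapter can type `watch_model` against it without importing any concrete algorithm class. A sketch of an adapter written against the new signature; the class name and print-based output are illustrative, not part of this commit:

```python
from typing import Optional

from d3rlpy.logging.logger import AlgProtocol


class PrintGradientAdapter:
    """Illustrative adapter that prints mean gradients (not part of d3rlpy)."""

    def watch_model(
        self,
        epoch: int,
        step: int,
        logging_steps: Optional[int],
        algo: AlgProtocol,
    ) -> None:
        assert algo.impl  # impl is Optional until training begins
        if logging_steps is not None and step % logging_steps == 0:
            for name, grad in algo.impl.modules.get_gradients():
                print(f"epoch={epoch} step={step} {name}: {grad.mean():.6f}")
```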
4 changes: 2 additions & 2 deletions d3rlpy/logging/noop_adapter.py
@@ -1,10 +1,10 @@
from typing import Any, Dict, Optional

from .logger import (
+   AlgProtocol,
    LoggerAdapter,
    LoggerAdapterFactory,
    SaveProtocol,
-   TorchModuleProtocol,
)

__all__ = ["NoopAdapter", "NoopAdapterFactory"]
@@ -42,7 +42,7 @@ def watch_model(
epoch: int,
step: int,
logging_steps: Optional[int],
- algo: TorchModuleProtocol,
+ algo: AlgProtocol,
) -> None:
pass

5 changes: 3 additions & 2 deletions d3rlpy/logging/tensorboard_adapter.py
@@ -4,10 +4,10 @@
import numpy as np

from .logger import (
+   AlgProtocol,
    LoggerAdapter,
    LoggerAdapterFactory,
    SaveProtocol,
-   TorchModuleProtocol,
)

__all__ = ["TensorboardAdapter", "TensorboardAdapterFactory"]
@@ -74,8 +74,9 @@ def watch_model(
epoch: int,
step: int,
logging_steps: Optional[int],
- algo: TorchModuleProtocol,
+ algo: AlgProtocol,
) -> None:
+ assert algo.impl
if logging_steps is not None and step % logging_steps == 0:
for name, grad in algo.impl.modules.get_gradients():
self._writer.add_histogram(
4 changes: 2 additions & 2 deletions d3rlpy/logging/utils.py
@@ -1,10 +1,10 @@
from typing import Any, Dict, Optional, Sequence

from .logger import (
+   AlgProtocol,
    LoggerAdapter,
    LoggerAdapterFactory,
    SaveProtocol,
-   TorchModuleProtocol,
)

__all__ = ["CombineAdapter", "CombineAdapterFactory"]
@@ -54,7 +54,7 @@ def watch_model(
epoch: int,
step: int,
logging_steps: Optional[int],
- algo: TorchModuleProtocol,
+ algo: AlgProtocol,
) -> None:
for adapter in self._adapters:
adapter.watch_model(epoch, step, logging_steps, algo)
5 changes: 3 additions & 2 deletions d3rlpy/logging/wandb_adapter.py
@@ -1,10 +1,10 @@
from typing import Any, Dict, Optional

from .logger import (
+   AlgProtocol,
    LoggerAdapter,
    LoggerAdapterFactory,
    SaveProtocol,
-   TorchModuleProtocol,
)

__all__ = ["WanDBAdapter", "WanDBAdapterFactory"]
@@ -63,9 +63,10 @@ def watch_model(
epoch: int,
step: int,
logging_steps: Optional[int],
- algo: TorchModuleProtocol,
+ algo: AlgProtocol,
) -> None:
    if not self._is_model_watched:
+       assert algo.impl
self.run.watch(
tuple(algo.impl.modules.get_torch_modules().values()),
log="gradients",
4 changes: 2 additions & 2 deletions tests/logging/test_logger.py
@@ -4,7 +4,7 @@
from torch import nn

from d3rlpy.logging import D3RLPyLogger
- from d3rlpy.logging.logger import SaveProtocol, TorchModuleProtocol
+ from d3rlpy.logging.logger import AlgProtocol, SaveProtocol


class StubLoggerAdapter:
@@ -46,7 +46,7 @@ def watch_model(
epoch: int,
step: int,
logging_step: int,
- algo: TorchModuleProtocol,
+ algo: AlgProtocol,
) -> None:
self.is_watch_model_called = True

