try simplify the implementation
hasan-yaman committed Oct 13, 2024
1 parent 502d526 commit 3cf717f
Showing 9 changed files with 113 additions and 128 deletions.
33 changes: 9 additions & 24 deletions d3rlpy/algos/qlearning/base.py
@@ -44,7 +44,6 @@
convert_to_torch,
convert_to_torch_recursively,
eval_api,
get_gradients,
hard_sync,
sync_optimizer_state,
train_api,
@@ -526,9 +525,7 @@ def fitter(
# save hyperparameters
save_config(self, logger)

# watch model gradients
if gradient_logging_steps is not None:
logger.watch_model(gradient_logging_steps, self)
logger.watch_model(0, 0, gradient_logging_steps, self)

# training loop
n_epochs = n_steps // n_steps_per_epoch
@@ -569,14 +566,9 @@ def fitter(

total_step += 1

if (
gradient_logging_steps is not None
and total_step % gradient_logging_steps == 0
):
for name, grad in get_gradients(
self.impl.modules.get_torch_modules()
):
logger.add_histogram(name=name, values=grad)
logger.watch_model(
epoch, total_step, gradient_logging_steps, self
)

if (
logging_strategy == LoggingStrategy.STEPS
@@ -694,9 +686,7 @@ def fit_online(
# save hyperparameters
save_config(self, logger)

# watch model gradients
if gradient_logging_steps is not None:
logger.watch_model(gradient_logging_steps, self)
logger.watch_model(0, 0, gradient_logging_steps, self)

# switch based on show_progress flag
xrange = trange if show_progress else range
@@ -766,21 +756,16 @@ def fit_online(
for name, val in loss.items():
logger.add_metric(name, val)

logger.watch_model(
epoch, total_step, gradient_logging_steps, self
)

if (
logging_strategy == LoggingStrategy.STEPS
and total_step % logging_steps == 0
):
logger.commit(epoch, total_step)

if (
gradient_logging_steps is not None
and total_step % gradient_logging_steps == 0
):
for name, grad in get_gradients(
self.impl.modules.get_torch_modules()
):
logger.add_histogram(name=name, values=grad)

# call callback if given
if callback:
callback(self, epoch, total_step)
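
Both training loops now call logger.watch_model(epoch, total_step, gradient_logging_steps, self) unconditionally instead of gating on gradient_logging_steps and writing histograms themselves. The adapters below rely on algo.impl.modules.get_gradients(), which is declared on ModuleProtocol in logger.py but whose implementation is not visible in the hunks shown here. A minimal sketch of what such a method could yield, assuming flattened per-parameter gradients as NumPy arrays:

# Hypothetical sketch only; the real Modules.get_gradients() added by this
# commit is not shown in this diff.
from typing import Iterator, List, Tuple

import numpy as np
from torch import nn


def get_gradients(modules: List[nn.Module]) -> Iterator[Tuple[str, np.ndarray]]:
    for module in modules:
        for name, parameter in module.named_parameters():
            if parameter.grad is not None:
                # Flattened so adapters can compute min/max/mean or histograms.
                yield name, parameter.grad.detach().cpu().numpy().flatten()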
34 changes: 17 additions & 17 deletions d3rlpy/logging/file_adapter.py
@@ -1,11 +1,10 @@
import json
import os
from enum import Enum, IntEnum
from typing import Any, Dict
from typing import Any, Dict, Optional

import numpy as np

from ..types import Float32NDArray
from .logger import (
LOG,
LoggerAdapter,
@@ -67,19 +66,6 @@ def write_metric(
with open(path, "a") as f:
print(f"{epoch},{step},{value}", file=f)

def write_histogram(
self, epoch: int, step: int, name: str, value: Float32NDArray
) -> None:
path = os.path.join(self._logdir, f"{name}.csv")
with open(path, "a") as f:
min_value = value.min()
max_value = value.max()
mean_value = value.mean()
print(
f"{epoch},{step},{name},{min_value},{max_value},{mean_value}",
file=f,
)

def after_write_metric(self, epoch: int, step: int) -> None:
pass

@@ -97,9 +83,23 @@ def logdir(self) -> str:
return self._logdir

def watch_model(
self, logging_steps: int, algo: TorchModuleProtocol
self,
epoch: int,
step: int,
logging_steps: Optional[int],
algo: TorchModuleProtocol,
) -> None:
pass
if logging_steps is not None and step % logging_steps == 0:
for name, grad in algo.impl.modules.get_gradients():
path = os.path.join(self._logdir, f"{name}.csv")
with open(path, "a") as f:
min_grad = grad.min()
max_grad = grad.max()
mean_grad = grad.mean()
print(
f"{epoch},{step},{name},{min_grad},{max_grad},{mean_grad}",
file=f,
)


class FileAdapterFactory(LoggerAdapterFactory):
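
With this change, FileAdapter appends one row per watched parameter to a dedicated CSV in the log directory whenever the step hits the logging interval, in the format epoch,step,name,min,max,mean with no header row. A hypothetical way to inspect such a file (the file name depends on the parameter names produced by get_gradients(), so the path below is made up):

import pandas as pd

# Hypothetical path; each watched parameter gets its own <name>.csv file.
grad_stats = pd.read_csv(
    "d3rlpy_logs/DQN_20241013/q_func.weight.csv",
    names=["epoch", "step", "name", "min", "max", "mean"],
)
print(grad_stats.tail())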
42 changes: 17 additions & 25 deletions d3rlpy/logging/logger.py
@@ -2,9 +2,8 @@
from collections import defaultdict
from contextlib import contextmanager
from datetime import datetime
from typing import Any, DefaultDict, Dict, Iterator, List
from typing import Any, DefaultDict, Dict, Iterator, List, Optional, Tuple

import numpy as np
import structlog
from torch import nn
from typing_extensions import Protocol
@@ -45,6 +44,7 @@ def save(self, fname: str) -> None: ...

class ModuleProtocol(Protocol):
def get_torch_modules(self) -> List[nn.Module]: ...
def get_gradients(self) -> Iterator[Tuple[str, Float32NDArray]]: ...


class ImplProtocol(Protocol):
@@ -85,18 +85,6 @@ def write_metric(
value: Metric value.
"""

def write_histogram(
self, epoch: int, step: int, name: str, values: Float32NDArray
) -> None:
r"""Writes histogram.
Args:
epoch: Epoch.
step: Training step.
name: Histogram name.
values: Histogram values.
"""

def after_write_metric(self, epoch: int, step: int) -> None:
r"""Callback executed after write_metric method.
@@ -117,11 +105,17 @@ def close(self) -> None:
r"""Closes this LoggerAdapter."""

def watch_model(
self, logging_steps: int, algo: TorchModuleProtocol
self,
epoch: int,
step: int,
logging_steps: Optional[int],
algo: TorchModuleProtocol,
) -> None:
r"""Watch model parameters / gradients during training.
Args:
epoch: Epoch.
step: Training step.
logging_steps: Interval in steps between gradient logging (``None`` disables it).
algo: Algorithm.
"""
@@ -161,7 +155,6 @@ def __init__(
self._experiment_name = experiment_name
self._adapter = adapter_factory.create(self._experiment_name)
self._metrics_buffer = defaultdict(list)
self._histogram_metrics_buffer = defaultdict(list)

def add_params(self, params: Dict[str, Any]) -> None:
self._adapter.write_params(params)
@@ -170,9 +163,6 @@ def add_params(self, params: Dict[str, Any]) -> None:
def add_metric(self, name: str, value: float) -> None:
self._metrics_buffer[name].append(value)

def add_histogram(self, name: str, values: Float32NDArray) -> None:
self._histogram_metrics_buffer[name].append(values)

def commit(self, epoch: int, step: int) -> Dict[str, float]:
self._adapter.before_write_metric(epoch, step)

@@ -182,10 +172,6 @@ def commit(self, epoch: int, step: int) -> Dict[str, float]:
self._adapter.write_metric(epoch, step, name, metric)
metrics[name] = metric

for name, buffer in self._histogram_metrics_buffer.items():
histogram_values = np.concatenate(buffer)
self._adapter.write_histogram(epoch, step, name, histogram_values)

LOG.info(
f"{self._experiment_name}: epoch={epoch} step={step}",
epoch=epoch,
@@ -218,5 +204,11 @@ def measure_time(self, name: str) -> Iterator[None]:
def adapter(self) -> LoggerAdapter:
return self._adapter

def watch_model(self, logging_steps: int, algo: TorchModuleProtocol) -> None:
self._adapter.watch_model(logging_steps, algo)
def watch_model(
self,
epoch: int,
step: int,
logging_steps: Optional[int],
algo: TorchModuleProtocol,
) -> None:
self._adapter.watch_model(epoch, step, logging_steps, algo)
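
Because watch_model now receives epoch and step, each adapter decides on its own whether the current step should be logged, which is what allows the two training loops above to drop their inline gating. A minimal sketch of a custom adapter under the new contract (only watch_model is shown; a real adapter also implements the remaining LoggerAdapter methods such as write_params, write_metric, save_model, and close):

from typing import Optional

from d3rlpy.logging.logger import TorchModuleProtocol


class PrintGradientAdapter:
    """Hypothetical adapter that prints gradient statistics to stdout."""

    def watch_model(
        self,
        epoch: int,
        step: int,
        logging_steps: Optional[int],
        algo: TorchModuleProtocol,
    ) -> None:
        # Same gating that previously lived in fitter() / fit_online().
        if logging_steps is not None and step % logging_steps == 0:
            for name, grad in algo.impl.modules.get_gradients():
                print(f"epoch={epoch} step={step} {name}: mean={grad.mean():.3e}")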
14 changes: 6 additions & 8 deletions d3rlpy/logging/noop_adapter.py
@@ -1,6 +1,5 @@
from typing import Any, Dict
from typing import Any, Dict, Optional

from ..types import Float32NDArray
from .logger import (
LoggerAdapter,
LoggerAdapterFactory,
@@ -29,11 +28,6 @@ def write_metric(
) -> None:
pass

def write_histogram(
self, epoch: int, step: int, name: str, values: Float32NDArray
) -> None:
pass

def after_write_metric(self, epoch: int, step: int) -> None:
pass

@@ -44,7 +38,11 @@ def close(self) -> None:
pass

def watch_model(
self, logging_steps: int, algo: TorchModuleProtocol
self,
epoch: int,
step: int,
logging_steps: Optional[int],
algo: TorchModuleProtocol,
) -> None:
pass

20 changes: 11 additions & 9 deletions d3rlpy/logging/tensorboard_adapter.py
@@ -1,9 +1,8 @@
import os
from typing import Any, Dict
from typing import Any, Dict, Optional

import numpy as np

from ..types import Float32NDArray
from .logger import (
LoggerAdapter,
LoggerAdapterFactory,
@@ -56,11 +55,6 @@ def write_metric(
self._writer.add_scalar(f"metrics/{name}", value, epoch)
self._metrics[name] = value

def write_histogram(
self, epoch: int, step: int, name: str, value: Float32NDArray
) -> None:
self._writer.add_histogram(f"histograms/{name}_grad", value, epoch)

def after_write_metric(self, epoch: int, step: int) -> None:
self._writer.add_hparams(
self._params,
@@ -76,9 +70,17 @@ def close(self) -> None:
self._writer.close()

def watch_model(
self, logging_steps: int, algo: TorchModuleProtocol
self,
epoch: int,
step: int,
logging_steps: Optional[int],
algo: TorchModuleProtocol,
) -> None:
pass
if logging_steps is not None and step % logging_steps == 0:
for name, grad in algo.impl.modules.get_gradients():
self._writer.add_histogram(
f"histograms/{name}_grad", grad, epoch
)


class TensorboardAdapterFactory(LoggerAdapterFactory):
17 changes: 7 additions & 10 deletions d3rlpy/logging/utils.py
@@ -1,6 +1,5 @@
from typing import Any, Dict, Sequence
from typing import Any, Dict, Optional, Sequence

from ..types import Float32NDArray
from .logger import (
LoggerAdapter,
LoggerAdapterFactory,
@@ -38,12 +37,6 @@ def write_metric(
for adapter in self._adapters:
adapter.write_metric(epoch, step, name, value)

def write_histogram(
self, epoch: int, step: int, name: str, values: Float32NDArray
) -> None:
for adapter in self._adapters:
adapter.write_histogram(epoch, step, name, values)

def after_write_metric(self, epoch: int, step: int) -> None:
for adapter in self._adapters:
adapter.after_write_metric(epoch, step)
@@ -57,10 +50,14 @@ def close(self) -> None:
adapter.close()

def watch_model(
self, logging_steps: int, algo: TorchModuleProtocol
self,
epoch: int,
step: int,
logging_steps: Optional[int],
algo: TorchModuleProtocol,
) -> None:
for adapter in self._adapters:
adapter.watch_model(logging_steps, algo)
adapter.watch_model(epoch, step, logging_steps, algo)


class CombineAdapterFactory(LoggerAdapterFactory):
25 changes: 13 additions & 12 deletions d3rlpy/logging/wandb_adapter.py
@@ -1,6 +1,5 @@
from typing import Any, Dict, Optional

from ..types import Float32NDArray
from .logger import (
LoggerAdapter,
LoggerAdapterFactory,
@@ -30,6 +29,7 @@ def __init__(
except ImportError as e:
raise ImportError("Please install wandb") from e
self.run = wandb.init(project=project, name=experiment_name)
self._is_model_watched = False

def write_params(self, params: Dict[str, Any]) -> None:
"""Writes hyperparameters to WandB config."""
@@ -44,11 +44,6 @@ def write_metric(
"""Writes metric to WandB."""
self.run.log({name: value, "epoch": epoch}, step=step)

def write_histogram(
self, epoch: int, step: int, name: str, values: Float32NDArray
) -> None:
pass

def after_write_metric(self, epoch: int, step: int) -> None:
"""Callback executed after writing metric."""

@@ -64,13 +59,19 @@ def close(self) -> None:
self.run.finish()

def watch_model(
self, logging_steps: int, algo: TorchModuleProtocol
self,
epoch: int,
step: int,
logging_steps: Optional[int],
algo: TorchModuleProtocol,
) -> None:
self.run.watch(
algo.impl.modules.get_torch_modules(),
log="gradients",
log_freq=logging_steps,
)
if not self._is_model_watched:
self.run.watch(
algo.impl.modules.get_torch_modules(),
log="gradients",
log_freq=logging_steps,
)
self._is_model_watched = True


class WanDBAdapterFactory(LoggerAdapterFactory):
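
Since watch_model is now invoked on every training step, the new _is_model_watched flag makes sure wandb's run.watch registers its gradient hooks only once; after that, W&B itself honors log_freq rather than the adapter re-reading gradients. A hedged end-to-end sketch of enabling gradient logging with this adapter; the keyword names logger_adapter and gradient_logging_steps and the WanDBAdapterFactory(project=...) argument are assumptions not confirmed by the visible hunks:

# Hypothetical usage; argument names below are assumptions based on the
# training-loop changes in this commit, not on the visible hunks.
from d3rlpy.algos import DQNConfig
from d3rlpy.datasets import get_cartpole
from d3rlpy.logging import WanDBAdapterFactory

dataset, env = get_cartpole()
dqn = DQNConfig().create(device="cpu")
dqn.fit(
    dataset,
    n_steps=10000,
    n_steps_per_epoch=1000,
    logger_adapter=WanDBAdapterFactory(project="my-project"),  # assumed argument
    gradient_logging_steps=100,  # assumed keyword for the logging interval
)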