Add CudaGraphWrapper
takuseno committed Nov 2, 2024
1 parent 3b01da3 commit f6de602
Showing 37 changed files with 455 additions and 152 deletions.
3 changes: 3 additions & 0 deletions d3rlpy/__init__.py
@@ -68,6 +68,9 @@ def seed(n: int) -> None:
# run healthcheck
run_healthcheck()

# enable autograd compilation
torch._dynamo.config.compiled_autograd = True
torch.set_float32_matmul_precision("high")

# register Shimmy if available
try:
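Note on the d3rlpy/__init__.py hunk above: both added lines are process-wide PyTorch settings applied at import time. torch._dynamo.config.compiled_autograd = True lets torch.compile trace and optimize the backward pass as well as the forward pass, and torch.set_float32_matmul_precision("high") permits TF32 matmul kernels on GPUs that support them. A minimal self-contained sketch of the same settings outside d3rlpy (the model and tensors are illustrative, not from this commit):

import torch

# The same two globals the commit sets at import time.
torch._dynamo.config.compiled_autograd = True  # also compile the backward pass
torch.set_float32_matmul_precision("high")     # allow TF32 matmuls where supported

model = torch.nn.Linear(8, 1)
compiled = torch.compile(model)  # forward compiled; backward handled by compiled autograd
loss = compiled(torch.randn(4, 8)).sum()
loss.backward()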
3 changes: 3 additions & 0 deletions d3rlpy/algos/qlearning/awac.py
@@ -70,6 +70,7 @@ class AWACConfig(LearnableConfig):
n_action_samples (int): Number of sampled actions to calculate
:math:`A^\pi(s_t, a_t)`.
n_critics (int): Number of Q functions for ensemble.
compile (bool): Flag to enable JIT compilation and CUDAGraph.
"""

actor_learning_rate: float = 3e-4
@@ -85,6 +86,7 @@ class AWACConfig(LearnableConfig):
lam: float = 1.0
n_action_samples: int = 1
n_critics: int = 2
compile: bool = False

def create(
self, device: DeviceArg = False, enable_ddp: bool = False
@@ -158,6 +160,7 @@ def inner_create_impl(
tau=self._config.tau,
lam=self._config.lam,
n_action_samples=self._config.n_action_samples,
compile=self._config.compile and "cuda" in self._device,
device=self._device,
)

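Every config in this commit gains the same compile flag, and AWAC and most of the other algorithms here forward it as compile=self._config.compile and "cuda" in self._device, so the flag only takes effect when the algorithm is created on a CUDA device (a few algorithms pass the flag through unguarded). A hedged usage sketch based only on the API visible in this diff, with dataset setup omitted:

import d3rlpy

config = d3rlpy.algos.AWACConfig(compile=True)
awac = config.create(device="cuda:0")  # compile honored on CUDA; on CPU the guard disables it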
6 changes: 6 additions & 0 deletions d3rlpy/algos/qlearning/bcq.py
@@ -137,6 +137,7 @@ class BCQConfig(LearnableConfig):
rl_start_step (int): Steps to start updating the policy function and Q
functions. If this is large, RL training will be more stable.
beta (float): KL regularization term for Conditional VAE.
compile (bool): Flag to enable JIT compilation and CUDAGraph.
"""

actor_learning_rate: float = 1e-3
@@ -159,6 +160,7 @@ class BCQConfig(LearnableConfig):
action_flexibility: float = 0.05
rl_start_step: int = 0
beta: float = 0.5
compile: bool = False

def create(
self, device: DeviceArg = False, enable_ddp: bool = False
@@ -264,6 +266,7 @@ def inner_create_impl(
action_flexibility=self._config.action_flexibility,
beta=self._config.beta,
rl_start_step=self._config.rl_start_step,
compile=self._config.compile and "cuda" in self._device,
device=self._device,
)

@@ -331,6 +334,7 @@ class DiscreteBCQConfig(LearnableConfig):
target_update_interval (int): Interval to update the target network.
share_encoder (bool): Flag to share encoder between Q-function and
imitation models.
compile (bool): Flag to enable JIT compilation and CUDAGraph.
"""

learning_rate: float = 6.25e-5
@@ -344,6 +348,7 @@ class DiscreteBCQConfig(LearnableConfig):
beta: float = 0.5
target_update_interval: int = 8000
share_encoder: bool = True
compile: bool = False

def create(
self, device: DeviceArg = False, enable_ddp: bool = False
@@ -422,6 +427,7 @@ def inner_create_impl(
gamma=self._config.gamma,
action_flexibility=self._config.action_flexibility,
beta=self._config.beta,
compile=self._config.compile and "cuda" in self._device,
device=self._device,
)

3 changes: 3 additions & 0 deletions d3rlpy/algos/qlearning/bear.py
@@ -114,6 +114,7 @@ class BEARConfig(LearnableConfig):
policy training.
warmup_steps (int): Number of steps to warm up the policy
function.
compile (bool): Flag to enable JIT compilation and CUDAGraph.
"""

actor_learning_rate: float = 1e-4
@@ -145,6 +146,7 @@ class BEARConfig(LearnableConfig):
mmd_sigma: float = 20.0
vae_kl_weight: float = 0.5
warmup_steps: int = 40000
compile: bool = False

def create(
self, device: DeviceArg = False, enable_ddp: bool = False
@@ -266,6 +268,7 @@ def inner_create_impl(
mmd_sigma=self._config.mmd_sigma,
vae_kl_weight=self._config.vae_kl_weight,
warmup_steps=self._config.warmup_steps,
compile=self._config.compile and "cuda" in self._device,
device=self._device,
)

2 changes: 2 additions & 0 deletions d3rlpy/algos/qlearning/cal_ql.py
@@ -69,6 +69,7 @@ class CalQLConfig(CQLConfig):
:math:`\log{\sum_a \exp{Q(s, a)}}`.
soft_q_backup (bool): Flag to use SAC-style backup.
max_q_backup (bool): Flag to sample max Q-values for target.
compile (bool): Flag to enable JIT compilation and CUDAGraph.
"""

def create(
@@ -171,6 +172,7 @@ def inner_create_impl(
n_action_samples=self._config.n_action_samples,
soft_q_backup=self._config.soft_q_backup,
max_q_backup=self._config.max_q_backup,
compile=self._config.compile,
device=self._device,
)

6 changes: 6 additions & 0 deletions d3rlpy/algos/qlearning/cql.py
@@ -100,6 +100,7 @@ class CQLConfig(LearnableConfig):
:math:`\log{\sum_a \exp{Q(s, a)}}`.
soft_q_backup (bool): Flag to use SAC-style backup.
max_q_backup (bool): Flag to sample max Q-values for target.
compile (bool): Flag to enable JIT compilation and CUDAGraph.
"""

actor_learning_rate: float = 1e-4
@@ -124,6 +125,7 @@ class CQLConfig(LearnableConfig):
n_action_samples: int = 10
soft_q_backup: bool = False
max_q_backup: bool = False
compile: bool = False

def create(
self, device: DeviceArg = False, enable_ddp: bool = False
@@ -225,6 +227,7 @@ def inner_create_impl(
n_action_samples=self._config.n_action_samples,
soft_q_backup=self._config.soft_q_backup,
max_q_backup=self._config.max_q_backup,
compile=self._config.compile and "cuda" in self._device,
device=self._device,
)

@@ -272,6 +275,7 @@ class DiscreteCQLConfig(LearnableConfig):
target_update_interval (int): Interval to synchronize the target
network.
alpha (float): :math:`\alpha` value above.
compile (bool): Flag to enable JIT compilation and CUDAGraph.
"""

learning_rate: float = 6.25e-5
@@ -283,6 +287,7 @@ class DiscreteCQLConfig(LearnableConfig):
n_critics: int = 1
target_update_interval: int = 8000
alpha: float = 1.0
compile: bool = False

def create(
self, device: DeviceArg = False, enable_ddp: bool = False
@@ -336,6 +341,7 @@ def inner_create_impl(
target_update_interval=self._config.target_update_interval,
gamma=self._config.gamma,
alpha=self._config.alpha,
compile=self._config.compile and "cuda" in self._device,
device=self._device,
)

3 changes: 3 additions & 0 deletions d3rlpy/algos/qlearning/ddpg.py
@@ -69,6 +69,7 @@ class DDPGConfig(LearnableConfig):
gamma (float): Discount factor.
tau (float): Target network synchronization coefficient.
n_critics (int): Number of Q functions for ensemble.
compile (bool): Flag to enable JIT compilation and CUDAGraph.
"""

batch_size: int = 256
@@ -81,6 +82,7 @@ class DDPGConfig(LearnableConfig):
q_func_factory: QFunctionFactory = make_q_func_field()
tau: float = 0.005
n_critics: int = 1
compile: bool = False

def create(
self, device: DeviceArg = False, enable_ddp: bool = False
@@ -153,6 +155,7 @@ def inner_create_impl(
targ_q_func_forwarder=targ_q_func_forwarder,
gamma=self._config.gamma,
tau=self._config.tau,
compile=self._config.compile,
device=self._device,
)

6 changes: 6 additions & 0 deletions d3rlpy/algos/qlearning/dqn.py
@@ -44,6 +44,7 @@ class DQNConfig(LearnableConfig):
gamma (float): Discount factor.
n_critics (int): Number of Q functions for ensemble.
target_update_interval (int): Interval to update the target network.
compile (bool): Flag to enable JIT compilation and CUDAGraph.
"""

batch_size: int = 32
@@ -54,6 +55,7 @@ class DQNConfig(LearnableConfig):
gamma: float = 0.99
n_critics: int = 1
target_update_interval: int = 8000
compile: bool = False

def create(
self, device: DeviceArg = False, enable_ddp: bool = False
@@ -106,6 +108,7 @@ def inner_create_impl(
target_update_interval=self._config.target_update_interval,
modules=modules,
gamma=self._config.gamma,
compile=self._config.compile,
device=self._device,
)

@@ -151,6 +154,7 @@ class DoubleDQNConfig(DQNConfig):
n_critics (int): Number of Q functions.
target_update_interval (int): Interval to synchronize the target
network.
compile (bool): Flag to enable JIT compilation and CUDAGraph.
"""

batch_size: int = 32
@@ -161,6 +165,7 @@ class DoubleDQNConfig(DQNConfig):
gamma: float = 0.99
n_critics: int = 1
target_update_interval: int = 8000
compile: bool = False

def create(
self, device: DeviceArg = False, enable_ddp: bool = False
@@ -213,6 +218,7 @@ def inner_create_impl(
targ_q_func_forwarder=targ_forwarder,
target_update_interval=self._config.target_update_interval,
gamma=self._config.gamma,
compile=self._config.compile,
device=self._device,
)

3 changes: 3 additions & 0 deletions d3rlpy/algos/qlearning/iql.py
@@ -80,6 +80,7 @@ class IQLConfig(LearnableConfig):
weight_temp (float): Inverse temperature value represented as
:math:`\beta`.
max_weight (float): Maximum advantage weight value to clip.
compile (bool): Flag to enable JIT compilation and CUDAGraph.
"""

actor_learning_rate: float = 3e-4
@@ -96,6 +97,7 @@ class IQLConfig(LearnableConfig):
expectile: float = 0.7
weight_temp: float = 3.0
max_weight: float = 100.0
compile: bool = False

def create(
self, device: DeviceArg = False, enable_ddp: bool = False
@@ -175,6 +177,7 @@ def inner_create_impl(
expectile=self._config.expectile,
weight_temp=self._config.weight_temp,
max_weight=self._config.max_weight,
compile=self._config.compile and "cuda" in self._device,
device=self._device,
)

5 changes: 5 additions & 0 deletions d3rlpy/algos/qlearning/plas.py
@@ -77,6 +77,7 @@ class PLASConfig(LearnableConfig):
lam (float): Weight factor for critic ensemble.
warmup_steps (int): Number of steps to warm up the VAE.
beta (float): KL regularization term for Conditional VAE.
compile (bool): Flag to enable JIT compilation and CUDAGraph.
"""

actor_learning_rate: float = 1e-4
@@ -96,6 +97,7 @@ class PLASConfig(LearnableConfig):
lam: float = 0.75
warmup_steps: int = 500000
beta: float = 0.5
compile: bool = False

def create(
self, device: DeviceArg = False, enable_ddp: bool = False
@@ -197,6 +199,7 @@ def inner_create_impl(
lam=self._config.lam,
beta=self._config.beta,
warmup_steps=self._config.warmup_steps,
compile=self._config.compile and "cuda" in self._device,
device=self._device,
)

@@ -247,6 +250,7 @@ class PLASWithPerturbationConfig(PLASConfig):
action_flexibility (float): Output scale of perturbation layer.
warmup_steps (int): Number of steps to warm up the VAE.
beta (float): KL regularization term for Conditional VAE.
compile (bool): Flag to enable JIT compilation and CUDAGraph.
"""

action_flexibility: float = 0.05
@@ -373,6 +377,7 @@ def inner_create_impl(
lam=self._config.lam,
beta=self._config.beta,
warmup_steps=self._config.warmup_steps,
compile=self._config.compile and "cuda" in self._device,
device=self._device,
)

3 changes: 3 additions & 0 deletions d3rlpy/algos/qlearning/rebrac.py
@@ -71,6 +71,7 @@ class ReBRACConfig(LearnableConfig):
critic_beta (float): :math:`\beta_2` value.
update_actor_interval (int): Interval to update policy function
described as `delayed policy update` in the paper.
compile (bool): Flag to enable JIT compilation and CUDAGraph.
"""

actor_learning_rate: float = 1e-3
@@ -89,6 +90,7 @@ class ReBRACConfig(LearnableConfig):
actor_beta: float = 0.001
critic_beta: float = 0.01
update_actor_interval: int = 2
compile: bool = False

def create(
self, device: DeviceArg = False, enable_ddp: bool = False
@@ -166,6 +168,7 @@ def inner_create_impl(
actor_beta=self._config.actor_beta,
critic_beta=self._config.critic_beta,
update_actor_interval=self._config.update_actor_interval,
compile=self._config.compile and "cuda" in self._device,
device=self._device,
)

3 changes: 3 additions & 0 deletions d3rlpy/algos/qlearning/sac.py
@@ -94,6 +94,7 @@ class SACConfig(LearnableConfig):
tau (float): Target network synchronization coefficient.
n_critics (int): Number of Q functions for ensemble.
initial_temperature (float): Initial temperature value.
compile (bool): Flag to enable JIT compilation and CUDAGraph.
"""

actor_learning_rate: float = 3e-4
@@ -110,6 +111,7 @@ class SACConfig(LearnableConfig):
tau: float = 0.005
n_critics: int = 2
initial_temperature: float = 1.0
compile: bool = False

def create(
self, device: DeviceArg = False, enable_ddp: bool = False
@@ -188,6 +190,7 @@ def inner_create_impl(
targ_q_func_forwarder=targ_q_func_forwarder,
gamma=self._config.gamma,
tau=self._config.tau,
compile=self._config.compile and "cuda" in self._device,
device=self._device,
)

3 changes: 3 additions & 0 deletions d3rlpy/algos/qlearning/td3.py
@@ -74,6 +74,7 @@ class TD3Config(LearnableConfig):
target_smoothing_clip (float): Clipping range for target noise.
update_actor_interval (int): Interval to update policy function
described as `delayed policy update` in the paper.
compile (bool): Flag to enable JIT compilation and CUDAGraph.
"""

actor_learning_rate: float = 3e-4
@@ -90,6 +91,7 @@ class TD3Config(LearnableConfig):
target_smoothing_sigma: float = 0.2
target_smoothing_clip: float = 0.5
update_actor_interval: int = 2
compile: bool = False

def create(
self, device: DeviceArg = False, enable_ddp: bool = False
@@ -165,6 +167,7 @@ def inner_create_impl(
target_smoothing_sigma=self._config.target_smoothing_sigma,
target_smoothing_clip=self._config.target_smoothing_clip,
update_actor_interval=self._config.update_actor_interval,
compile=self._config.compile,
device=self._device,
)

(Diffs for the remaining changed files are not shown.)
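The CudaGraphWrapper named in the commit title is defined in one of the files not shown above. Purely as orientation — an assumption about the approach, not d3rlpy's actual code — a wrapper of this kind usually follows PyTorch's documented capture/replay recipe with torch.cuda.CUDAGraph: warm up the wrapped function a few times on a side stream, capture one invocation into a graph with static input/output buffers, then replay the graph after copying fresh inputs into the static buffers.

import torch

class CudaGraphWrapper:  # hypothetical sketch, not the implementation from this commit
    def __init__(self, fn, warmup_steps=3):
        self._fn = fn
        self._warmup_steps = warmup_steps
        self._step = 0
        self._graph = None
        self._static_in = None
        self._static_out = None

    def __call__(self, x):
        if self._step < self._warmup_steps:
            # Warm-up runs happen on a side stream, as the PyTorch CUDA Graphs docs recommend.
            self._step += 1
            stream = torch.cuda.Stream()
            stream.wait_stream(torch.cuda.current_stream())
            with torch.cuda.stream(stream):
                out = self._fn(x)
            torch.cuda.current_stream().wait_stream(stream)
            return out
        if self._graph is None:
            # Capture one invocation; tensors used here become the graph's static buffers.
            self._static_in = x.clone()
            self._graph = torch.cuda.CUDAGraph()
            with torch.cuda.graph(self._graph):
                self._static_out = self._fn(self._static_in)
            return self._static_out
        # Replay: refresh the static input, then rerun the captured kernel sequence.
        self._static_in.copy_(x)
        self._graph.replay()
        return self._static_out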
