Skip to content

Commit

Permalink
Fix lint error
Browse files Browse the repository at this point in the history
  • Loading branch information
takuseno committed Nov 3, 2024
1 parent 911d542 commit ebd4756
Show file tree
Hide file tree
Showing 7 changed files with 11 additions and 11 deletions.
2 changes: 1 addition & 1 deletion d3rlpy/algos/qlearning/torch/bcq_impl.py
Original file line number Diff line number Diff line change
Expand Up @@ -91,7 +91,7 @@ def __init__(
self._beta = beta
self._rl_start_step = rl_start_step
self._compute_imitator_grad = (
CudaGraphWrapper(self.compute_imitator_grad) # type: ignore
CudaGraphWrapper(self.compute_imitator_grad)
if compile_graph
else self.compute_imitator_grad
)
Expand Down
4 changes: 2 additions & 2 deletions d3rlpy/algos/qlearning/torch/bear_impl.py
Original file line number Diff line number Diff line change
Expand Up @@ -116,12 +116,12 @@ def __init__(
self._vae_kl_weight = vae_kl_weight
self._warmup_steps = warmup_steps
self._compute_warmup_actor_grad = (
CudaGraphWrapper(self.compute_warmup_actor_grad) # type: ignore
CudaGraphWrapper(self.compute_warmup_actor_grad)
if compile_graph
else self.compute_warmup_actor_grad
)
self._compute_imitator_grad = (
CudaGraphWrapper(self.compute_imitator_grad) # type: ignore
CudaGraphWrapper(self.compute_imitator_grad)
if compile_graph
else self.compute_imitator_grad
)
Expand Down
4 changes: 2 additions & 2 deletions d3rlpy/algos/qlearning/torch/ddpg_impl.py
Original file line number Diff line number Diff line change
Expand Up @@ -87,12 +87,12 @@ def __init__(
self._q_func_forwarder = q_func_forwarder
self._targ_q_func_forwarder = targ_q_func_forwarder
self._compute_critic_grad = (
CudaGraphWrapper(self.compute_critic_grad) # type: ignore
CudaGraphWrapper(self.compute_critic_grad)
if compile_graph
else self.compute_critic_grad
)
self._compute_actor_grad = (
CudaGraphWrapper(self.compute_actor_grad) # type: ignore
CudaGraphWrapper(self.compute_actor_grad)
if compile_graph
else self.compute_actor_grad
)
Expand Down
2 changes: 1 addition & 1 deletion d3rlpy/algos/qlearning/torch/dqn_impl.py
Original file line number Diff line number Diff line change
Expand Up @@ -64,7 +64,7 @@ def __init__(
self._targ_q_func_forwarder = targ_q_func_forwarder
self._target_update_interval = target_update_interval
self._compute_grad = (
CudaGraphWrapper(self.compute_grad) # type: ignore
CudaGraphWrapper(self.compute_grad)
if compile_graph
else self.compute_grad
)
Expand Down
2 changes: 1 addition & 1 deletion d3rlpy/algos/qlearning/torch/plas_impl.py
Original file line number Diff line number Diff line change
Expand Up @@ -71,7 +71,7 @@ def __init__(
self._beta = beta
self._warmup_steps = warmup_steps
self._compute_imitator_grad = (
CudaGraphWrapper(self.compute_imitator_grad) # type: ignore
CudaGraphWrapper(self.compute_imitator_grad)
if compile_graph
else self.compute_imitator_grad
)
Expand Down
4 changes: 2 additions & 2 deletions d3rlpy/algos/qlearning/torch/sac_impl.py
Original file line number Diff line number Diff line change
Expand Up @@ -174,12 +174,12 @@ def __init__(
self._targ_q_func_forwarder = targ_q_func_forwarder
self._target_update_interval = target_update_interval
self._compute_critic_grad = (
CudaGraphWrapper(self.compute_critic_grad) # type: ignore
CudaGraphWrapper(self.compute_critic_grad)
if compile_graph
else self.compute_critic_grad
)
self._compute_actor_grad = (
CudaGraphWrapper(self.compute_actor_grad) # type: ignore
CudaGraphWrapper(self.compute_actor_grad)
if compile_graph
else self.compute_actor_grad
)
Expand Down
4 changes: 2 additions & 2 deletions d3rlpy/algos/transformer/torch/decision_transformer_impl.py
Original file line number Diff line number Diff line change
Expand Up @@ -47,7 +47,7 @@ def __init__(
):
super().__init__(observation_shape, action_size, modules, device)
self._compute_grad = (
CudaGraphWrapper(self.compute_grad) # type: ignore
CudaGraphWrapper(self.compute_grad)
if compile_graph
else self.compute_grad
)
Expand Down Expand Up @@ -122,7 +122,7 @@ def __init__(
self._final_tokens = final_tokens
self._initial_learning_rate = initial_learning_rate
self._compute_grad = (
CudaGraphWrapper(self.compute_grad) # type: ignore
CudaGraphWrapper(self.compute_grad)
if compile_graph
else self.compute_grad
)
Expand Down

0 comments on commit ebd4756

Please sign in to comment.