
Commit

Modify OptimFactory interface to take named_modules instead of parameters
takuseno committed Oct 7, 2023
1 parent 468ce52 commit 154f3c0
Showing 18 changed files with 90 additions and 74 deletions.
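
Every call site in this commit changes the same way: the optimizer factory now receives the (name, module) pairs produced by torch.nn.Module.named_modules() instead of the flat iterable from parameters(), which gives the factory enough structure to apply per-module settings (for example, excluding some modules from weight decay). The OptimFactory implementation itself is not part of this diff, so the AdamFactory below is only a hypothetical sketch of what a factory conforming to the new interface could look like:

from typing import Iterable, Tuple

import torch
from torch import nn


class AdamFactory:
    """Hypothetical factory targeting the new named_modules interface."""

    def __init__(self, weight_decay: float = 0.0):
        self._weight_decay = weight_decay

    def create(
        self, named_modules: Iterable[Tuple[str, nn.Module]], lr: float
    ) -> torch.optim.Adam:
        # named_modules() yields every submodule, so take each module's
        # directly-owned parameters (recurse=False) and key them by fully
        # qualified name to avoid collecting the same parameter twice.
        params = {}
        for name, module in named_modules:
            for pname, param in module.named_parameters(recurse=False):
                params[f"{name}.{pname}"] = param
        return torch.optim.Adam(
            params.values(), lr=lr, weight_decay=self._weight_decay
        )


# Call sites then change from
#     factory.create(model.parameters(), lr=3e-4)
# to
#     factory.create(model.named_modules(), lr=3e-4)
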
4 changes: 2 additions & 2 deletions d3rlpy/algos/qlearning/awac.py
@@ -124,10 +124,10 @@ def inner_create_impl(
         )

         actor_optim = self._config.actor_optim_factory.create(
-            policy.parameters(), lr=self._config.actor_learning_rate
+            policy.named_modules(), lr=self._config.actor_learning_rate
         )
         critic_optim = self._config.critic_optim_factory.create(
-            q_funcs.parameters(), lr=self._config.critic_learning_rate
+            q_funcs.named_modules(), lr=self._config.critic_learning_rate
         )

         dummy_log_temp = Parameter(torch.zeros(1, 1))

4 changes: 2 additions & 2 deletions d3rlpy/algos/qlearning/bc.py
@@ -88,7 +88,7 @@ def inner_create_impl(
             raise ValueError(f"invalid policy_type: {self._config.policy_type}")

         optim = self._config.optim_factory.create(
-            imitator.parameters(), lr=self._config.learning_rate
+            imitator.named_modules(), lr=self._config.learning_rate
         )

         modules = BCModules(optim=optim, imitator=imitator)
@@ -159,7 +159,7 @@ def inner_create_impl(
         )

         optim = self._config.optim_factory.create(
-            imitator.parameters(), lr=self._config.learning_rate
+            imitator.named_modules(), lr=self._config.learning_rate
         )

         modules = DiscreteBCModules(optim=optim, imitator=imitator)

18 changes: 6 additions & 12 deletions d3rlpy/algos/qlearning/bcq.py
@@ -211,13 +211,13 @@ def inner_create_impl(
         )

         actor_optim = self._config.actor_optim_factory.create(
-            policy.parameters(), lr=self._config.actor_learning_rate
+            policy.named_modules(), lr=self._config.actor_learning_rate
         )
         critic_optim = self._config.critic_optim_factory.create(
-            q_funcs.parameters(), lr=self._config.critic_learning_rate
+            q_funcs.named_modules(), lr=self._config.critic_learning_rate
         )
         imitator_optim = self._config.imitator_optim_factory.create(
-            imitator.parameters(), lr=self._config.imitator_learning_rate
+            imitator.named_modules(), lr=self._config.imitator_learning_rate
         )

         modules = BCQModules(
@@ -373,16 +373,10 @@ def inner_create_impl(
             device=self._device,
         )

-        # TODO: replace this with a cleaner way
-        # retrieve unique elements
-        q_func_params = list(q_funcs.parameters())
-        imitator_params = list(imitator.parameters())
-        unique_dict = {}
-        for param in q_func_params + imitator_params:
-            unique_dict[param] = param
-        unique_params = list(unique_dict.values())
+        q_func_params = list(q_funcs.named_modules())
+        imitator_params = list(imitator.named_modules())
         optim = self._config.optim_factory.create(
-            unique_params, lr=self._config.learning_rate
+            q_func_params + imitator_params, lr=self._config.learning_rate
         )

         modules = DiscreteBCQModules(

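The DiscreteBCQ hunk above is where the old interface was most awkward: the Q-functions and the imitator can share submodules (hence the removed "retrieve unique elements" workaround), so their parameter lists had to be deduplicated by identity at the call site. With (name, module) pairs the call site simply concatenates the two lists, which suggests deduplication moves into the factory. A minimal sketch of that idea, assuming deduplication by module identity (the factory internals are not shown in this diff):

from typing import Iterable, List, Tuple

from torch import nn


def unique_parameters(
    named_modules: Iterable[Tuple[str, nn.Module]]
) -> List[nn.Parameter]:
    # Skip module objects seen before so that submodules shared between
    # networks (e.g. a shared encoder) contribute their parameters once.
    seen = set()
    params: List[nn.Parameter] = []
    for _, module in named_modules:
        if id(module) in seen:
            continue
        seen.add(id(module))
        params.extend(module.parameters(recurse=False))
    return params
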
10 changes: 5 additions & 5 deletions d3rlpy/algos/qlearning/bear.py
@@ -197,19 +197,19 @@ def inner_create_impl(
         )

         actor_optim = self._config.actor_optim_factory.create(
-            policy.parameters(), lr=self._config.actor_learning_rate
+            policy.named_modules(), lr=self._config.actor_learning_rate
         )
         critic_optim = self._config.critic_optim_factory.create(
-            q_funcs.parameters(), lr=self._config.critic_learning_rate
+            q_funcs.named_modules(), lr=self._config.critic_learning_rate
         )
         imitator_optim = self._config.imitator_optim_factory.create(
-            imitator.parameters(), lr=self._config.imitator_learning_rate
+            imitator.named_modules(), lr=self._config.imitator_learning_rate
         )
         temp_optim = self._config.temp_optim_factory.create(
-            log_temp.parameters(), lr=self._config.temp_learning_rate
+            log_temp.named_modules(), lr=self._config.temp_learning_rate
         )
         alpha_optim = self._config.alpha_optim_factory.create(
-            log_alpha.parameters(), lr=self._config.actor_learning_rate
+            log_alpha.named_modules(), lr=self._config.actor_learning_rate
         )

         modules = BEARModules(

10 changes: 5 additions & 5 deletions d3rlpy/algos/qlearning/cql.py
@@ -166,20 +166,20 @@ def inner_create_impl(
         )

         actor_optim = self._config.actor_optim_factory.create(
-            policy.parameters(), lr=self._config.actor_learning_rate
+            policy.named_modules(), lr=self._config.actor_learning_rate
         )
         critic_optim = self._config.critic_optim_factory.create(
-            q_funcs.parameters(), lr=self._config.critic_learning_rate
+            q_funcs.named_modules(), lr=self._config.critic_learning_rate
         )
         if self._config.temp_learning_rate > 0:
             temp_optim = self._config.temp_optim_factory.create(
-                log_temp.parameters(), lr=self._config.temp_learning_rate
+                log_temp.named_modules(), lr=self._config.temp_learning_rate
             )
         else:
             temp_optim = None
         if self._config.alpha_learning_rate > 0:
             alpha_optim = self._config.alpha_optim_factory.create(
-                log_alpha.parameters(), lr=self._config.alpha_learning_rate
+                log_alpha.named_modules(), lr=self._config.alpha_learning_rate
             )
         else:
             alpha_optim = None
@@ -296,7 +296,7 @@ def inner_create_impl(
         )

         optim = self._config.optim_factory.create(
-            q_funcs.parameters(), lr=self._config.learning_rate
+            q_funcs.named_modules(), lr=self._config.learning_rate
         )

         modules = DQNModules(

4 changes: 2 additions & 2 deletions d3rlpy/algos/qlearning/crr.py
@@ -162,10 +162,10 @@ def inner_create_impl(
         )

         actor_optim = self._config.actor_optim_factory.create(
-            policy.parameters(), lr=self._config.actor_learning_rate
+            policy.named_modules(), lr=self._config.actor_learning_rate
         )
         critic_optim = self._config.critic_optim_factory.create(
-            q_funcs.parameters(), lr=self._config.critic_learning_rate
+            q_funcs.named_modules(), lr=self._config.critic_learning_rate
         )

         modules = CRRModules(

4 changes: 2 additions & 2 deletions d3rlpy/algos/qlearning/ddpg.py
@@ -123,10 +123,10 @@ def inner_create_impl(
         )

         actor_optim = self._config.actor_optim_factory.create(
-            policy.parameters(), lr=self._config.actor_learning_rate
+            policy.named_modules(), lr=self._config.actor_learning_rate
         )
         critic_optim = self._config.critic_optim_factory.create(
-            q_funcs.parameters(), lr=self._config.critic_learning_rate
+            q_funcs.named_modules(), lr=self._config.critic_learning_rate
         )

         modules = DDPGModules(

4 changes: 2 additions & 2 deletions d3rlpy/algos/qlearning/dqn.py
@@ -84,7 +84,7 @@ def inner_create_impl(
         )

         optim = self._config.optim_factory.create(
-            q_funcs.parameters(), lr=self._config.learning_rate
+            q_funcs.named_modules(), lr=self._config.learning_rate
         )

         modules = DQNModules(
@@ -186,7 +186,7 @@ def inner_create_impl(
         )

         optim = self._config.optim_factory.create(
-            q_funcs.parameters(), lr=self._config.learning_rate
+            q_funcs.named_modules(), lr=self._config.learning_rate
         )

         modules = DQNModules(

6 changes: 3 additions & 3 deletions d3rlpy/algos/qlearning/iql.py
@@ -140,10 +140,10 @@ def inner_create_impl(
         )

         actor_optim = self._config.actor_optim_factory.create(
-            policy.parameters(), lr=self._config.actor_learning_rate
+            policy.named_modules(), lr=self._config.actor_learning_rate
         )
-        q_func_params = list(q_funcs.parameters())
-        v_func_params = list(value_func.parameters())
+        q_func_params = list(q_funcs.named_modules())
+        v_func_params = list(value_func.named_modules())
         critic_optim = self._config.critic_optim_factory.create(
             q_func_params + v_func_params, lr=self._config.critic_learning_rate
         )

2 changes: 1 addition & 1 deletion d3rlpy/algos/qlearning/nfq.py
@@ -86,7 +86,7 @@ def inner_create_impl(
         )

         optim = self._config.optim_factory.create(
-            q_funcs.parameters(), lr=self._config.learning_rate
+            q_funcs.named_modules(), lr=self._config.learning_rate
         )

         modules = DQNModules(

16 changes: 8 additions & 8 deletions d3rlpy/algos/qlearning/plas.py
@@ -146,13 +146,13 @@ def inner_create_impl(
         )

         actor_optim = self._config.actor_optim_factory.create(
-            policy.parameters(), lr=self._config.actor_learning_rate
+            policy.named_modules(), lr=self._config.actor_learning_rate
         )
         critic_optim = self._config.critic_optim_factory.create(
-            q_funcs.parameters(), lr=self._config.critic_learning_rate
+            q_funcs.named_modules(), lr=self._config.critic_learning_rate
         )
         imitator_optim = self._config.critic_optim_factory.create(
-            imitator.parameters(), lr=self._config.imitator_learning_rate
+            imitator.named_modules(), lr=self._config.imitator_learning_rate
         )

         modules = PLASModules(
@@ -296,16 +296,16 @@ def inner_create_impl(
             device=self._device,
         )

-        parameters = list(policy.parameters())
-        parameters += list(perturbation.parameters())
+        named_modules = list(policy.named_modules())
+        named_modules += list(perturbation.named_modules())
         actor_optim = self._config.actor_optim_factory.create(
-            params=parameters, lr=self._config.actor_learning_rate
+            named_modules, lr=self._config.actor_learning_rate
         )
         critic_optim = self._config.critic_optim_factory.create(
-            q_funcs.parameters(), lr=self._config.critic_learning_rate
+            q_funcs.named_modules(), lr=self._config.critic_learning_rate
         )
         imitator_optim = self._config.critic_optim_factory.create(
-            imitator.parameters(), lr=self._config.imitator_learning_rate
+            imitator.named_modules(), lr=self._config.imitator_learning_rate
         )

         modules = PLASWithPerturbationModules(

12 changes: 6 additions & 6 deletions d3rlpy/algos/qlearning/sac.py
@@ -151,14 +151,14 @@ def inner_create_impl(
         )

         actor_optim = self._config.actor_optim_factory.create(
-            policy.parameters(), lr=self._config.actor_learning_rate
+            policy.named_modules(), lr=self._config.actor_learning_rate
         )
         critic_optim = self._config.critic_optim_factory.create(
-            q_funcs.parameters(), lr=self._config.critic_learning_rate
+            q_funcs.named_modules(), lr=self._config.critic_learning_rate
         )
         if self._config.temp_learning_rate > 0:
             temp_optim = self._config.temp_optim_factory.create(
-                log_temp.parameters(), lr=self._config.temp_learning_rate
+                log_temp.named_modules(), lr=self._config.temp_learning_rate
             )
         else:
             temp_optim = None
@@ -299,14 +299,14 @@ def inner_create_impl(
         )

         critic_optim = self._config.critic_optim_factory.create(
-            q_funcs.parameters(), lr=self._config.critic_learning_rate
+            q_funcs.named_modules(), lr=self._config.critic_learning_rate
         )
         actor_optim = self._config.actor_optim_factory.create(
-            policy.parameters(), lr=self._config.actor_learning_rate
+            policy.named_modules(), lr=self._config.actor_learning_rate
         )
         if self._config.temp_learning_rate > 0:
             temp_optim = self._config.temp_optim_factory.create(
-                log_temp.parameters(), lr=self._config.temp_learning_rate
+                log_temp.named_modules(), lr=self._config.temp_learning_rate
             )
         else:
             temp_optim = None

4 changes: 2 additions & 2 deletions d3rlpy/algos/qlearning/td3.py
@@ -132,10 +132,10 @@ def inner_create_impl(
         )

         actor_optim = self._config.actor_optim_factory.create(
-            policy.parameters(), lr=self._config.actor_learning_rate
+            policy.named_modules(), lr=self._config.actor_learning_rate
         )
         critic_optim = self._config.critic_optim_factory.create(
-            q_funcs.parameters(), lr=self._config.critic_learning_rate
+            q_funcs.named_modules(), lr=self._config.critic_learning_rate
         )

         modules = DDPGModules(

4 changes: 2 additions & 2 deletions d3rlpy/algos/qlearning/td3_plus_bc.py
@@ -124,10 +124,10 @@ def inner_create_impl(
         )

         actor_optim = self._config.actor_optim_factory.create(
-            policy.parameters(), lr=self._config.actor_learning_rate
+            policy.named_modules(), lr=self._config.actor_learning_rate
         )
         critic_optim = self._config.critic_optim_factory.create(
-            q_funcs.parameters(), lr=self._config.critic_learning_rate
+            q_funcs.named_modules(), lr=self._config.critic_learning_rate
         )

         modules = DDPGModules(

4 changes: 2 additions & 2 deletions d3rlpy/algos/transformer/decision_transformer.py
@@ -114,7 +114,7 @@ def inner_create_impl(
             device=self._device,
         )
         optim = self._config.optim_factory.create(
-            transformer.parameters(), lr=self._config.learning_rate
+            transformer.named_modules(), lr=self._config.learning_rate
         )
         scheduler = torch.optim.lr_scheduler.LambdaLR(
             optim, lambda steps: min((steps + 1) / self._config.warmup_steps, 1)
@@ -234,7 +234,7 @@ def inner_create_impl(
             device=self._device,
         )
         optim = self._config.optim_factory.create(
-            transformer.parameters(), lr=self._config.learning_rate
+            transformer.named_modules(), lr=self._config.learning_rate
         )
         # JIT compile
         if self._config.compile:

(Diffs for the remaining three changed files did not load and are not shown.)