Simplify encoder interface
takuseno committed Aug 4, 2023
1 parent 1caa8d7 commit 3466976
Showing 30 changed files with 487 additions and 935 deletions.
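The change repeated across these files follows one pattern: torch modules no longer ask their encoder for its own feature size (the `encoder.get_feature_size()` call still visible in the decision-transformer hunks below), and the builders instead measure the encoder's output width once with `compute_output_size` and pass it in explicitly as `hidden_size` (or `feature_size`). As a hedged sketch of what such a probe has to do (this is an illustration, not the actual helper in `d3rlpy.models.torch`), the width can be read off a dummy forward pass:

```python
# Hedged sketch of a feature-size probe; the real compute_output_size in
# d3rlpy.models.torch may differ (batch size, dtype, device handling).
from typing import Sequence

import torch
from torch import nn


def compute_output_size_sketch(
    input_shapes: Sequence[Sequence[int]],
    module: nn.Module,
    device: str = "cpu",
) -> int:
    # One dummy tensor per input shape, each with a batch dimension of 1.
    # The module is assumed to already live on `device`.
    inputs = [torch.rand((1, *shape), device=device) for shape in input_shapes]
    with torch.no_grad():
        output = module(*inputs)
    # The encoder's output width is the size of the last dimension.
    return int(output.shape[-1])
```

Because the width is probed at build time, encoders only need to implement `forward`, which appears to be what the commit title, "Simplify encoder interface", refers to.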
12 changes: 10 additions & 2 deletions d3rlpy/algos/qlearning/bcq.py
@@ -14,7 +14,7 @@
from ...models.encoders import EncoderFactory, make_encoder_field
from ...models.optimizers import OptimizerFactory, make_optimizer_field
from ...models.q_functions import QFunctionFactory, make_q_func_field
from ...models.torch import DiscreteImitator, PixelEncoder
from ...models.torch import DiscreteImitator, PixelEncoder, compute_output_size
from ...torch_utility import TorchMiniBatch
from .base import QLearningAlgoBase
from .torch.bcq_impl import BCQImpl, DiscreteBCQImpl
@@ -338,8 +338,16 @@ def inner_create_impl(

# share convolutional layers if observation is pixel
if isinstance(q_func.q_funcs[0].encoder, PixelEncoder):
hidden_size = compute_output_size(
[observation_shape],
q_func.q_funcs[0].encoder,
device=self._device,
)
imitator = DiscreteImitator(
q_func.q_funcs[0].encoder, action_size, self._config.beta
encoder=q_func.q_funcs[0].encoder,
hidden_size=hidden_size,
action_size=action_size,
beta=self._config.beta,
)
imitator.to(self._device)
else:
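The first argument to `compute_output_size` is a list of input shapes, which is why `observation_shape` is wrapped in a list above: the same helper then also covers encoders that consume the action alongside the observation, as in the continuous Q-function and conditional-VAE builders further down. A small self-contained demonstration, assuming the helper accepts any `nn.Module` with a matching `forward` and returns the width of its output, as its use in these hunks suggests (the two tiny encoders are stand-ins written only for this example):

```python
import torch
from torch import nn

# The import path mirrors the one added to bcq.py above.
from d3rlpy.models.torch import compute_output_size


class TinyEncoder(nn.Module):
    def __init__(self, observation_dim: int, hidden: int = 64) -> None:
        super().__init__()
        self.fc = nn.Linear(observation_dim, hidden)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        return torch.relu(self.fc(x))


class TinyEncoderWithAction(nn.Module):
    def __init__(self, observation_dim: int, action_size: int, hidden: int = 64) -> None:
        super().__init__()
        self.fc = nn.Linear(observation_dim + action_size, hidden)

    def forward(self, x: torch.Tensor, action: torch.Tensor) -> torch.Tensor:
        return torch.relu(self.fc(torch.cat([x, action], dim=1)))


observation_shape = (8,)
action_size = 2

# Observation-only encoder: a single shape in the list.
print(compute_output_size([observation_shape], TinyEncoder(8), "cpu"))  # -> 64

# Observation-plus-action encoder: two shapes, matching forward(x, action).
encoder_with_action = TinyEncoderWithAction(8, action_size)
print(
    compute_output_size(
        [observation_shape, (action_size,)], encoder_with_action, "cpu"
    )
)  # -> 64
```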
20 changes: 17 additions & 3 deletions d3rlpy/algos/qlearning/torch/bc_impl.py
@@ -13,6 +13,7 @@
Policy,
ProbablisticRegressor,
SquashedNormalPolicy,
compute_output_size,
)
from ....torch_utility import TorchMiniBatch, hard_sync, train_api
from ..base import QLearningAlgoImplBase
@@ -94,14 +94,27 @@ def __init__(
@property
def policy(self) -> Policy:
policy: Policy
hidden_size = compute_output_size(
[self._observation_shape, (self._action_size,)],
self._imitator.encoder,
device=self._device,
)
if self._policy_type == "deterministic":
hidden_size = compute_output_size(
[self._observation_shape, (self._action_size,)],
self._imitator.encoder,
device=self._device,
)
policy = DeterministicPolicy(
self._imitator.encoder, self._action_size
encoder=self._imitator.encoder,
hidden_size=hidden_size,
action_size=self._action_size,
)
elif self._policy_type == "stochastic":
return SquashedNormalPolicy(
self._imitator.encoder,
self._action_size,
encoder=self._imitator.encoder,
hidden_size=hidden_size,
action_size=self._action_size,
min_logstd=-4.0,
max_logstd=15.0,
use_std_parameter=False,
96 changes: 78 additions & 18 deletions d3rlpy/models/builders.py
@@ -24,6 +24,7 @@
SimplePositionEncoding,
SquashedNormalPolicy,
ValueFunction,
compute_output_size,
)
from .utility import create_activation

@@ -56,6 +57,7 @@ def create_discrete_q_function(
) -> EnsembleDiscreteQFunction:
if q_func_factory.share_encoder:
encoder = encoder_factory.create(observation_shape)
hidden_size = compute_output_size([observation_shape], encoder, device)
# normalize gradient scale by ensemble size
for p in cast(nn.Module, encoder).parameters():
p.register_hook(lambda grad: grad / n_ensembles)
@@ -64,7 +66,12 @@
for _ in range(n_ensembles):
if not q_func_factory.share_encoder:
encoder = encoder_factory.create(observation_shape)
q_funcs.append(q_func_factory.create_discrete(encoder, action_size))
hidden_size = compute_output_size(
[observation_shape], encoder, device
)
q_funcs.append(
q_func_factory.create_discrete(encoder, hidden_size, action_size)
)
q_func = EnsembleDiscreteQFunction(q_funcs)
q_func.to(device)
return q_func
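One detail in the shared-encoder branch above is easy to miss: when every ensemble member reuses the same encoder, its parameters receive `n_ensembles` gradient contributions per backward pass, so the builder registers a hook that rescales the gradient by the ensemble size. A standalone toy illustration of that hook (the module and sizes are made up for this example; the hook line is the one from the hunk above):

```python
import torch
from torch import nn

n_ensembles = 3
shared_encoder = nn.Linear(4, 2)  # stand-in for the shared encoder

# Same normalization as in create_discrete_q_function above: gradients that
# accumulate into the shared parameters are divided by the ensemble size.
for p in shared_encoder.parameters():
    p.register_hook(lambda grad: grad / n_ensembles)

x = torch.randn(5, 4)
# Three "heads" reusing the same encoder, like the ensemble Q-functions do.
loss = sum(shared_encoder(x).pow(2).mean() for _ in range(n_ensembles))
loss.backward()

# Each parameter now holds the average, rather than the sum, of the three
# per-member gradient contributions.
print(shared_encoder.weight.grad)
```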
@@ -82,6 +89,9 @@
encoder = encoder_factory.create_with_action(
observation_shape, action_size
)
hidden_size = compute_output_size(
[observation_shape, (action_size,)], encoder, device
)
# normalize gradient scale by ensemble size
for p in cast(nn.Module, encoder).parameters():
p.register_hook(lambda grad: grad / n_ensembles)
@@ -92,7 +102,12 @@
encoder = encoder_factory.create_with_action(
observation_shape, action_size
)
q_funcs.append(q_func_factory.create_continuous(encoder))
hidden_size = compute_output_size(
[observation_shape, (action_size,)], encoder, device
)
q_funcs.append(
q_func_factory.create_continuous(encoder, hidden_size, action_size)
)
q_func = EnsembleContinuousQFunction(q_funcs)
q_func.to(device)
return q_func
@@ -105,7 +120,12 @@ def create_deterministic_policy(
device: str,
) -> DeterministicPolicy:
encoder = encoder_factory.create(observation_shape)
policy = DeterministicPolicy(encoder, action_size)
hidden_size = compute_output_size([observation_shape], encoder, device)
policy = DeterministicPolicy(
encoder=encoder,
hidden_size=hidden_size,
action_size=action_size,
)
policy.to(device)
return policy

@@ -118,7 +138,15 @@ def create_deterministic_residual_policy(
device: str,
) -> DeterministicResidualPolicy:
encoder = encoder_factory.create_with_action(observation_shape, action_size)
policy = DeterministicResidualPolicy(encoder, scale)
hidden_size = compute_output_size(
[observation_shape, (action_size,)], encoder, device
)
policy = DeterministicResidualPolicy(
encoder=encoder,
hidden_size=hidden_size,
action_size=action_size,
scale=scale,
)
policy.to(device)
return policy

@@ -133,9 +161,11 @@ def create_squashed_normal_policy(
use_std_parameter: bool = False,
) -> SquashedNormalPolicy:
encoder = encoder_factory.create(observation_shape)
hidden_size = compute_output_size([observation_shape], encoder, device)
policy = SquashedNormalPolicy(
encoder,
action_size,
encoder=encoder,
hidden_size=hidden_size,
action_size=action_size,
min_logstd=min_logstd,
max_logstd=max_logstd,
use_std_parameter=use_std_parameter,
@@ -154,9 +184,11 @@ def create_non_squashed_normal_policy(
use_std_parameter: bool = False,
) -> NonSquashedNormalPolicy:
encoder = encoder_factory.create(observation_shape)
hidden_size = compute_output_size([observation_shape], encoder, device)
policy = NonSquashedNormalPolicy(
encoder,
action_size,
encoder=encoder,
hidden_size=hidden_size,
action_size=action_size,
min_logstd=min_logstd,
max_logstd=max_logstd,
use_std_parameter=use_std_parameter,
@@ -172,7 +204,10 @@ def create_categorical_policy(
device: str,
) -> CategoricalPolicy:
encoder = encoder_factory.create(observation_shape)
policy = CategoricalPolicy(encoder, action_size)
hidden_size = compute_output_size([observation_shape], encoder, device)
policy = CategoricalPolicy(
encoder=encoder, hidden_size=hidden_size, action_size=action_size
)
policy.to(device)
return policy

@@ -193,10 +228,16 @@ def create_conditional_vae(
decoder_encoder = encoder_factory.create_with_action(
observation_shape, latent_size
)
hidden_size = compute_output_size(
[observation_shape, (action_size,)], encoder_encoder, device
)
policy = ConditionalVAE(
encoder_encoder,
decoder_encoder,
beta,
encoder_encoder=encoder_encoder,
decoder_encoder=decoder_encoder,
hidden_size=hidden_size,
latent_size=latent_size,
action_size=action_size,
beta=beta,
min_logstd=min_logstd,
max_logstd=max_logstd,
)
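Note that `ConditionalVAE` now also receives `latent_size` and `action_size` explicitly rather than deriving them from the two encoders (the old call passed only the encoders and `beta`). A hedged sketch of why those sizes are needed alongside `hidden_size` (this is not the d3rlpy class, just an illustration of how the output heads would be dimensioned):

```python
from torch import nn

hidden_size, latent_size, action_size = 256, 6, 3  # example values only

# Encoder path: hidden features -> Gaussian parameters over the latent code.
mu_head = nn.Linear(hidden_size, latent_size)
logstd_head = nn.Linear(hidden_size, latent_size)

# Decoder path: hidden features -> reconstructed action.
action_head = nn.Linear(hidden_size, action_size)
```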
@@ -212,7 +253,13 @@ def create_discrete_imitator(
device: str,
) -> DiscreteImitator:
encoder = encoder_factory.create(observation_shape)
imitator = DiscreteImitator(encoder, action_size, beta)
hidden_size = compute_output_size([observation_shape], encoder, device)
imitator = DiscreteImitator(
encoder=encoder,
hidden_size=hidden_size,
action_size=action_size,
beta=beta,
)
imitator.to(device)
return imitator

@@ -224,7 +271,12 @@ def create_deterministic_regressor(
device: str,
) -> DeterministicRegressor:
encoder = encoder_factory.create(observation_shape)
regressor = DeterministicRegressor(encoder, action_size)
hidden_size = compute_output_size([observation_shape], encoder, device)
regressor = DeterministicRegressor(
encoder=encoder,
hidden_size=hidden_size,
action_size=action_size,
)
regressor.to(device)
return regressor

@@ -238,8 +290,13 @@ def create_probablistic_regressor(
max_logstd: float = 2.0,
) -> ProbablisticRegressor:
encoder = encoder_factory.create(observation_shape)
hidden_size = compute_output_size([observation_shape], encoder, device)
regressor = ProbablisticRegressor(
encoder, action_size, min_logstd=min_logstd, max_logstd=max_logstd
encoder=encoder,
hidden_size=hidden_size,
action_size=action_size,
min_logstd=min_logstd,
max_logstd=max_logstd,
)
regressor.to(device)
return regressor
@@ -249,7 +306,8 @@ def create_value_function(
observation_shape: Shape, encoder_factory: EncoderFactory, device: str
) -> ValueFunction:
encoder = encoder_factory.create(observation_shape)
value_func = ValueFunction(encoder)
hidden_size = compute_output_size([observation_shape], encoder, device)
value_func = ValueFunction(encoder, hidden_size)
value_func.to(device)
return value_func
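The value-function builder follows the same recipe, with `hidden_size` passed positionally. The extra argument presumably just sizes the scalar output head; a hedged sketch of that structure (not the d3rlpy `ValueFunction`):

```python
import torch
from torch import nn


class ValueFunctionSketch(nn.Module):
    """Hedged sketch: an encoder followed by a width-hidden_size scalar head."""

    def __init__(self, encoder: nn.Module, hidden_size: int) -> None:
        super().__init__()
        self._encoder = encoder
        self._fc = nn.Linear(hidden_size, 1)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        return self._fc(self._encoder(x))
```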

@@ -279,7 +337,7 @@ def create_continuous_decision_transformer(
device: str,
) -> ContinuousDecisionTransformer:
encoder = encoder_factory.create(observation_shape)
hidden_size = encoder.get_feature_size()
hidden_size = compute_output_size([observation_shape], encoder, device)

if position_encoding_type == "simple":
position_encoding = SimplePositionEncoding(hidden_size, max_timestep)
@@ -294,6 +352,7 @@

transformer = ContinuousDecisionTransformer(
encoder=encoder,
feature_size=hidden_size,
position_encoding=position_encoding,
action_size=action_size,
num_heads=num_heads,
@@ -324,7 +383,7 @@ def create_discrete_decision_transformer(
device: str,
) -> DiscreteDecisionTransformer:
encoder = encoder_factory.create(observation_shape)
hidden_size = encoder.get_feature_size()
hidden_size = compute_output_size([observation_shape], encoder, device)

if position_encoding_type == "simple":
position_encoding = SimplePositionEncoding(hidden_size, max_timestep)
@@ -339,6 +398,7 @@

transformer = DiscreteDecisionTransformer(
encoder=encoder,
feature_size=hidden_size,
position_encoding=position_encoding,
action_size=action_size,
num_heads=num_heads,
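The two decision-transformer hunks show the encoder-facing side of the change most directly: the only call into the old encoder interface, `encoder.get_feature_size()`, is replaced by the external probe, and the transformers instead receive the width as an explicit `feature_size` argument. If that method is no longer required anywhere (only the call sites are visible in this commit, so read this as a hedged interpretation), a custom encoder reduces to a plain `nn.Module` with a `forward`:

```python
# Hedged sketch of a custom encoder under the simplified interface: no
# get_feature_size() is defined, because builders measure the output width
# themselves with a dummy forward pass.
import torch
from torch import nn


class MyCustomEncoder(nn.Module):
    def __init__(self, observation_dim: int) -> None:
        super().__init__()
        self.net = nn.Sequential(
            nn.Linear(observation_dim, 128),
            nn.ReLU(),
            nn.Linear(128, 64),
            nn.ReLU(),
        )

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        # Output width (64 here) is discovered at build time, not declared.
        return self.net(x)
```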
(Diffs for the remaining 27 changed files are not shown here.)