Add a config file that trains only sioconvps #191

Merged 2 commits on Nov 14, 2024
153 changes: 153 additions & 0 deletions configs/models/learn_only_sioconvps_large.yaml
@@ -0,0 +1,153 @@
image_encoder: i_jepa_target_encoder # Alias for ImageEncodingAgent.
image_decoder: i_jepa_target_visualization_decoder # Alias for MultiStepCuriosityAgent.

i_jepa_target_encoder:
_target_: ami.models.model_wrapper.ModelWrapper
default_device: ${devices.0}
has_inference: True
parameter_file: ???
inference_forward:
_target_: hydra.utils.get_method
path: ami.models.bool_mask_i_jepa.i_jepa_encoder_infer
model:
_target_: ami.models.bool_mask_i_jepa.BoolMaskIJEPAEncoder
img_size:
- ${shared.image_height} # Assuming 144
- ${shared.image_width} # Assuming 144
in_channels: ${shared.image_channels} # Assuming 3
patch_size: 12
embed_dim: 512
out_dim: 32 # Compresses the raw input to 1/13.5 of its size: (144 * 144 * 3) / (144 patches * 32) = 13.5.
depth: 4
num_heads: 4
mlp_ratio: 4.0
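# Note: with the assumed 144x144 input and patch_size 12, this encoder emits
# (144 / 12) * (144 / 12) = 144 patch tokens of dimension out_dim (32);
# the num_stack values further down rely on this token count.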

forward_dynamics:
_target_: ami.models.model_wrapper.ModelWrapper
default_device: ${devices.0}
has_inference: True
model:
_target_: ami.models.forward_dynamics.ForwardDynamcisWithActionReward
observation_flatten:
_target_: ami.models.components.stacked_features.LerpStackedFeatures
dim_in: ${models.i_jepa_target_encoder.model.out_dim}
dim_out: 2048
num_stack: ${python.eval:"12 * 12"} # assuming image size is 144x144, patch size is 12x12.
action_flatten:
_target_: ami.models.components.multi_embeddings.MultiEmbeddings
choices_per_category:
_target_: hydra.utils.get_object
path: ami.interactions.environments.actuators.vrchat_osc_discrete_actuator.ACTION_CHOICES_PER_CATEGORY
embedding_dim: 8
do_flatten: True
obs_action_projection:
_target_: torch.nn.Linear
# action_embedding_dim * num_action_categories (5) + observation_flatten.dim_out
in_features: ${python.eval:"${..action_flatten.embedding_dim} * 5 + ${..observation_flatten.dim_out}"}
out_features: ${..core_model.dim}
core_model:
_target_: ami.models.components.sioconvps.SioConvPS
depth: 12
dim: 2048
dim_ff_hidden: ${python.eval:"${.dim} * 2"}
dropout: 0.1
obs_hat_dist_head:
# _target_: ami.models.components.stacked_features.ToStackedFeaturesMDN
# dim_in: ${..core_model.dim}
# dim_out: ${models.i_jepa_target_encoder.model.out_dim}
# num_stack: ${..observation_flatten.num_stack}
# num_components: 8
# eps: 0.01
_target_: torch.nn.Sequential
_args_:
- _target_: ami.models.components.stacked_features.ToStackedFeatures
dim_in: ${....core_model.dim}
dim_out: ${models.i_jepa_target_encoder.model.out_dim}
num_stack: ${....observation_flatten.num_stack}
- _target_: ami.models.components.fully_connected_fixed_std_normal.FullyConnectedFixedStdNormal
dim_in: ${..0.dim_out}
dim_out: ${models.i_jepa_target_encoder.model.out_dim}
normal_cls: Deterministic
action_hat_dist_head:
_target_: ami.models.components.discrete_policy_head.DiscretePolicyHead
dim_in: ${..core_model.dim}
action_choices_per_category:
_target_: hydra.utils.get_object
path: ami.interactions.environments.actuators.vrchat_osc_discrete_actuator.ACTION_CHOICES_PER_CATEGORY
reward_hat_dist_head:
_target_: ami.models.components.fully_connected_fixed_std_normal.FullyConnectedFixedStdNormal
dim_in: ${..core_model.dim}
dim_out: 1
# _target_: ami.models.components.mixture_desity_network.NormalMixtureDensityNetwork
# in_features: ${..core_model.dim}
# out_features: 1
# num_components: 8
squeeze_feature_dim: True
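# Data flow of this block, as wired above: the 144 x 32 stacked observation features
# are flattened by LerpStackedFeatures to 2048 dims, concatenated with the flattened
# action embeddings (5 categories x 8 dims = 40), projected back to the 2048-dim
# SioConvPS input, and passed through the depth-12 core; the three heads then output
# the observation, action, and reward distributions.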

policy:
_target_: ami.models.model_wrapper.ModelWrapper
default_device: ${devices.0}
has_inference: True
inference_thread_only: True
model:
_target_: ami.models.csv_files_policy.CSVFilesPolicy
file_paths: ???
column_headers: ???
dtype: ${torch.dtype:long}
device: ${devices.0}
observation_ndim: 2
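# CSVFilesPolicy presumably replays actions recorded in CSV files; file_paths and
# column_headers are mandatory (???) and must be supplied at runtime, and
# inference_thread_only: True keeps the policy out of training, consistent with
# this config training only SioConvPS.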

value:
_target_: ami.models.model_wrapper.ModelWrapper
default_device: ${devices.0}
has_inference: True
inference_thread_only: True
model:
_target_: ami.models.policy_or_value_network.PolicyOrValueNetwork
observation_projection:
_target_: ami.models.components.stacked_features.LerpStackedFeatures
dim_in: ${models.i_jepa_target_encoder.model.out_dim}
dim_out: 1
num_stack: ${python.eval:"12 * 12"} # assuming image size is 144x144, patch size is 12x12.
forward_dynamics_hidden_projection:
_target_: ami.models.components.stacked_features.LerpStackedFeatures
dim_in: ${models.forward_dynamics.model.core_model.dim}
dim_out: ${.dim_in}
num_stack: ${models.forward_dynamics.model.core_model.depth}
observation_hidden_projection:
_target_: ami.models.policy_value_common_net.ConcatFlattenedObservationAndLerpedHidden
dim_obs: ${..observation_projection.dim_out}
dim_hidden: ${models.forward_dynamics.model.core_model.dim}
dim_out: 1
core_model:
_target_: ami.models.components.resnet.ResNetFF
dim: ${..observation_hidden_projection.dim_out}
dim_hidden: 1
depth: 1
dist_head:
_target_: ami.models.components.fully_connected_fixed_std_normal.FullyConnectedFixedStdNormal
dim_in: ${..observation_hidden_projection.dim_out}
dim_out: 1
squeeze_feature_dim: True
normal_cls:
_target_: hydra.utils.get_class
path: ami.models.components.fully_connected_fixed_std_normal.DeterministicNormal
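# Every projection width in this value head is set to 1, so it appears to be a
# minimal stub; like the policy it is inference_thread_only and not trained here.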

i_jepa_target_visualization_decoder:
_target_: ami.models.model_wrapper.ModelWrapper
default_device: ${devices.0}
parameter_file: ???
has_inference: True
model:
_target_: ami.models.i_jepa_latent_visualization_decoder.IJEPALatentVisualizationDecoder
input_n_patches:
- ${python.eval:"${shared.image_height} // ${models.i_jepa_target_encoder.model.patch_size}"}
- ${python.eval:"${shared.image_width} // ${models.i_jepa_target_encoder.model.patch_size}"}
input_latents_dim: ${models.i_jepa_target_encoder.model.out_dim}
decoder_blocks_in_and_out_channels:
- [512, 512]
- [512, 256]
- [256, 128]
- [128, 64]
n_res_blocks: 3
num_heads: 4
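
As a sanity check on the interpolated sizes, here is a small standalone Python sketch (not part of the diff) that reproduces the arithmetic the `${python.eval:...}` expressions above depend on. The 144x144 image size and the 5 action categories of `ACTION_CHOICES_PER_CATEGORY` are assumptions taken from the config's own comments and the `* 5` factor, not values read from the repository.

```python
# Standalone sanity check of the sizes this config wires together.
# Assumptions: 144x144 input image (per the config comments) and
# 5 action categories in ACTION_CHOICES_PER_CATEGORY (per the "* 5" eval).

image_height = image_width = 144  # assumed ${shared.image_height} / ${shared.image_width}
patch_size = 12                   # i_jepa_target_encoder.model.patch_size
encoder_out_dim = 32              # i_jepa_target_encoder.model.out_dim

# Patch tokens produced by the encoder; must match the num_stack used by
# LerpStackedFeatures, i.e. ${python.eval:"12 * 12"} = 144.
num_patches = (image_height // patch_size) * (image_width // patch_size)
assert num_patches == 144

# Compression ratio referenced by the "1 / 13.5" comment on out_dim.
assert (image_height * image_width * 3) / (num_patches * encoder_out_dim) == 13.5

action_embedding_dim = 8          # forward_dynamics.model.action_flatten.embedding_dim
num_action_categories = 5         # assumed len(ACTION_CHOICES_PER_CATEGORY)
obs_flat_dim = 2048               # forward_dynamics.model.observation_flatten.dim_out

# obs_action_projection.in_features = embedding_dim * 5 + observation_flatten.dim_out
in_features = action_embedding_dim * num_action_categories + obs_flat_dim
assert in_features == 2088

# SioConvPS feed-forward hidden width: ${python.eval:"${.dim} * 2"}
core_dim = 2048
assert core_dim * 2 == 4096

print("Dimension arithmetic is consistent.")
```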