-
Notifications
You must be signed in to change notification settings - Fork 438
/
factory_task_insertion.py
199 lines (152 loc) · 9.07 KB
/
factory_task_insertion.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
# Copyright (c) 2021-2023, NVIDIA Corporation
# All rights reserved.
#
# Redistribution and use in source and binary forms, with or without
# modification, are permitted provided that the following conditions are met:
#
# 1. Redistributions of source code must retain the above copyright notice, this
# list of conditions and the following disclaimer.
#
# 2. Redistributions in binary form must reproduce the above copyright notice,
# this list of conditions and the following disclaimer in the documentation
# and/or other materials provided with the distribution.
#
# 3. Neither the name of the copyright holder nor the names of its
# contributors may be used to endorse or promote products derived from
# this software without specific prior written permission.
#
# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
# AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
# DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
# FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
# DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
# SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
# CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
# OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
"""Factory: Class for insertion task.
Inherits insertion environment class and abstract task class (not enforced). Can be executed with
python train.py task=FactoryTaskInsertion
Only the environment is provided; training a successful RL policy is an open research problem left to the user.
"""
import hydra
import math
import omegaconf
import os
import torch
from isaacgym import gymapi, gymtorch
from isaacgymenvs.tasks.factory.factory_env_insertion import FactoryEnvInsertion
from isaacgymenvs.tasks.factory.factory_schema_class_task import FactoryABCTask
from isaacgymenvs.tasks.factory.factory_schema_config_task import FactorySchemaConfigTask
class FactoryTaskInsertion(FactoryEnvInsertion, FactoryABCTask):
    """Insertion task layered on top of the Factory insertion environment.

    The class wires up configuration loading, the per-step reset/observe/reward
    pipeline, and the reset logic for the Franka and the plug. The task-specific
    hooks (`_acquire_task_tensors`, `_refresh_task_tensors`, `_update_rew_buf`,
    `_update_reset_buf`) are intentionally empty: only the environment is
    provided and the reward/termination design is left to the user.
    """

    def __init__(self, cfg, rl_device, sim_device, graphics_device_id, headless, virtual_screen_capture, force_render):
        """Initialize instance variables. Initialize task superclass.

        All arguments are forwarded unchanged to the superclass constructor;
        `cfg` is additionally stored on the instance and used to build the task config.
        """

        super().__init__(cfg, rl_device, sim_device, graphics_device_id, headless, virtual_screen_capture, force_render)

        self.cfg = cfg
        self._get_task_yaml_params()

        # Identity comparison: a viewer object must never be asked to compare itself with None
        if self.viewer is not None:
            self._set_viewer_params()

        if self.cfg_base.mode.export_scene:
            self.export_scene(label='franka_task_insertion')

    def _get_task_yaml_params(self):
        """Initialize instance variables from YAML files.

        Sets `cfg_task`, `max_episode_length`, `asset_info_insertion` and `cfg_ppo`.
        """

        cs = hydra.core.config_store.ConfigStore.instance()
        cs.store(name='factory_schema_config_task', node=FactorySchemaConfigTask)

        self.cfg_task = omegaconf.OmegaConf.create(self.cfg)
        self.max_episode_length = self.cfg_task.rl.max_episode_length  # required instance var for VecTask

        asset_info_path = '../../assets/factory/yaml/factory_asset_info_insertion.yaml'  # relative to Gym's Hydra search path (cfg dir)
        self.asset_info_insertion = hydra.compose(config_name=asset_info_path)
        # The composed config nests the payload under one key per path component
        # (the '..' components show up as empty keys)
        self.asset_info_insertion = self.asset_info_insertion['']['']['']['']['']['']['assets']['factory']['yaml']  # strip superfluous nesting

        ppo_path = 'train/FactoryTaskInsertionPPO.yaml'  # relative to Gym's Hydra search path (cfg dir)
        self.cfg_ppo = hydra.compose(config_name=ppo_path)
        self.cfg_ppo = self.cfg_ppo['train']  # strip superfluous nesting

    def _acquire_task_tensors(self):
        """Acquire tensors. Currently a no-op placeholder."""

        pass

    def _refresh_task_tensors(self):
        """Refresh tensors. Currently a no-op placeholder."""

        pass

    def pre_physics_step(self, actions):
        """Reset environments. Apply actions from policy as position/rotation targets, force/torque targets, and/or PD gains.

        Environments flagged in `reset_buf` are reset first; the incoming actions are
        then copied onto the task device and stored in `_actions`.
        """

        env_ids = self.reset_buf.nonzero(as_tuple=False).squeeze(-1)
        if len(env_ids) > 0:
            self.reset_idx(env_ids)

        self._actions = actions.clone().to(self.device)  # shape = (num_envs, num_actions); values = [-1, 1]

    def post_physics_step(self):
        """Step buffers. Refresh tensors. Compute observations and reward."""

        self.progress_buf[:] += 1

        self.refresh_base_tensors()
        self.refresh_env_tensors()
        self._refresh_task_tensors()
        self.compute_observations()
        self.compute_reward()

    def compute_observations(self):
        """Compute observations.

        Returns the observation buffer as-is; no task-specific observations are written here.
        """

        return self.obs_buf  # shape = (num_envs, num_observations)

    def compute_reward(self):
        """Detect successes and failures. Update reward and reset buffers."""

        self._update_rew_buf()
        self._update_reset_buf()

    def _update_rew_buf(self):
        """Compute reward at current timestep. Currently a no-op placeholder."""

        pass

    def _update_reset_buf(self):
        """Assign environments for reset if successful or failed. Currently a no-op placeholder."""

        pass

    def reset_idx(self, env_ids):
        """Reset specified environments.

        Resets the Franka, then the plug, then the reset/progress buffers of `env_ids`.
        """

        self._reset_franka(env_ids)
        self._reset_object(env_ids)

        # Single source of truth for buffer resets (previously duplicated inline)
        self._reset_buffers(env_ids)

    def _reset_franka(self, env_ids):
        """Reset DOF states and DOF targets of Franka for the environments in `env_ids`."""

        # shape of dof_pos = (num_envs, num_dofs)
        # shape of dof_vel = (num_envs, num_dofs)

        # Initialize Franka to middle of joint limits, plus joint noise
        franka_dof_props = self.gym.get_actor_dof_properties(self.env_ptrs[0],
                                                             self.franka_handles[0])  # same across all envs
        lower_lims = franka_dof_props['lower']
        upper_lims = franka_dof_props['upper']

        # Only touch the environments being reset: one noise sample per reset env, shared
        # across its joints (broadcast over the DOF axis). Writing every row here would
        # clobber the joint positions of environments that are still mid-episode.
        num_resets = len(env_ids)
        joint_midpoints = torch.tensor((lower_lims + upper_lims) * 0.5, device=self.device)
        joint_noise = (torch.rand((num_resets, 1), device=self.device) * 2.0 - 1.0) \
            * self.cfg_task.randomize.joint_noise * math.pi / 180  # config value is in degrees
        self.dof_pos[env_ids, 0:self.franka_num_dofs] = joint_midpoints + joint_noise

        self.dof_vel[env_ids, 0:self.franka_num_dofs] = 0.0

        # Push the new DOF state to the simulator for the reset actors only
        franka_actor_ids_sim_int32 = self.franka_actor_ids_sim.to(dtype=torch.int32, device=self.device)[env_ids]
        self.gym.set_dof_state_tensor_indexed(self.sim,
                                              gymtorch.unwrap_tensor(self.dof_state),
                                              gymtorch.unwrap_tensor(franka_actor_ids_sim_int32),
                                              len(franka_actor_ids_sim_int32))

        # Make the position targets match the new pose so the controller does not pull the arm away
        self.ctrl_target_dof_pos[env_ids, 0:self.franka_num_dofs] = self.dof_pos[env_ids, 0:self.franka_num_dofs]
        self.gym.set_dof_position_target_tensor(self.sim, gymtorch.unwrap_tensor(self.ctrl_target_dof_pos))

    def _reset_object(self, env_ids):
        """Reset root state of plug for the environments in `env_ids`.

        With `randomize.initial_state == 'random'` the plug is placed at a noisy XY location
        at a fixed height above the table; with `'goal'` it is placed at the origin at table
        height. Any other value leaves the position untouched. Velocities are always zeroed.
        """

        # shape of root_pos = (num_envs, num_actors, 3)
        # shape of root_quat = (num_envs, num_actors, 4)
        # shape of root_linvel = (num_envs, num_actors, 3)
        # shape of root_angvel = (num_envs, num_actors, 3)

        if self.cfg_task.randomize.initial_state == 'random':
            # Sample one position per environment being reset. Sizing these by num_envs would
            # make the assignment below fail whenever only a subset of environments is reset.
            num_resets = len(env_ids)
            noise_xy = self.cfg_task.randomize.plug_noise_xy
            plug_x = (torch.rand((num_resets, 1), device=self.device) * 2.0 - 1.0) * noise_xy
            plug_y = self.cfg_task.randomize.plug_bias_y \
                + (torch.rand((num_resets, 1), device=self.device) * 2.0 - 1.0) * noise_xy
            plug_z = torch.ones((num_resets, 1), device=self.device) \
                * (self.cfg_base.env.table_height + self.cfg_task.randomize.plug_bias_z)
            self.root_pos[env_ids, self.plug_actor_id_env] = torch.cat((plug_x, plug_y, plug_z), dim=1)
        elif self.cfg_task.randomize.initial_state == 'goal':
            self.root_pos[env_ids, self.plug_actor_id_env] = torch.tensor([0.0, 0.0, self.cfg_base.env.table_height],
                                                                          device=self.device)

        self.root_linvel[env_ids, self.plug_actor_id_env] = 0.0
        self.root_angvel[env_ids, self.plug_actor_id_env] = 0.0

        # Push the new root state to the simulator for the reset plugs only
        plug_actor_ids_sim_int32 = self.plug_actor_ids_sim.to(dtype=torch.int32, device=self.device)
        reset_plug_ids = plug_actor_ids_sim_int32[env_ids]
        self.gym.set_actor_root_state_tensor_indexed(self.sim,
                                                     gymtorch.unwrap_tensor(self.root_state),
                                                     gymtorch.unwrap_tensor(reset_plug_ids),
                                                     len(reset_plug_ids))

    def _reset_buffers(self, env_ids):
        """Reset buffers. Clears the reset flag and the step counter of `env_ids`."""

        self.reset_buf[env_ids] = 0
        self.progress_buf[env_ids] = 0

    def _set_viewer_params(self):
        """Set viewer parameters. Points the viewer camera from (-1, -1, 1) at (0, 0, 0.5)."""

        cam_pos = gymapi.Vec3(-1.0, -1.0, 1.0)
        cam_target = gymapi.Vec3(0.0, 0.0, 0.5)
        self.gym.viewer_camera_look_at(self.viewer, None, cam_pos, cam_target)