-
Notifications
You must be signed in to change notification settings - Fork 20
/
collect_pull.py
executable file
·139 lines (121 loc) · 5.25 KB
/
collect_pull.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
#!/usr/bin/env python2
from __future__ import print_function
import argparse
import os
import sys
import time
sys.path.extend(os.path.abspath(os.path.join(os.getcwd(), d))
for d in ['pddlstream', 'ss-pybullet'])
from pybullet_tools.pr2_primitives import Conf
from pybullet_tools.utils import wait_for_user, elapsed_time, multiply, \
invert, get_link_pose, has_gui, write_json, get_body_name, get_link_name, \
get_joint_name, joint_from_name, get_date, SEPARATOR, safe_remove, link_from_name
from src.utils import CABINET_JOINTS, DRAWER_JOINTS, KNOBS, ZED_LEFT_JOINTS
from src.world import World
from src.streams.press import get_press_gen_fn
from src.streams.pull import get_pull_gen_fn
from src.database import get_joint_reference_pose, get_pull_path, is_press
def _get_sampler(world, joint_name, args):
    """Create a zero-argument sampler for ``joint_name``.

    Returns a tuple ``(sample, open_conf, closed_conf)``. Each call to
    ``sample()`` builds a fresh generator and returns its first output, or
    ``None`` when the generator yields nothing.

    - For press targets (``is_press(joint_name)``) the press generator is used
      and both confs are ``None``.
    - Otherwise the pull generator is asked to move the joint from
      ``open_conf`` to ``closed_conf``, which are single-joint ``Conf`` objects
      on ``world.kitchen``.
    """
    collisions = not args.cfree
    if is_press(joint_name):
        press_gen = get_press_gen_fn(world, collisions=collisions,
                                     teleport=args.teleport, learned=False)

        def sample():
            return next(press_gen(joint_name), None)
        return sample, None, None

    joint = joint_from_name(world.kitchen, joint_name)
    open_conf = Conf(world.kitchen, [joint], [world.open_conf(joint)])
    closed_conf = Conf(world.kitchen, [joint], [world.closed_conf(joint)])
    pull_gen = get_pull_gen_fn(world, collisions=collisions,
                               teleport=args.teleport, learned=False)

    def sample():
        # Open to closed
        return next(pull_gen(joint_name, open_conf, closed_conf), None)
    return sample, open_conf, closed_conf


# TODO: generalize to any manipulation with a fixed entity
def collect_pull(world, joint_name, args):
    """Collect robot placements for operating one kitchen joint or knob, and save them.

    Draws samples until ``args.num_samples`` successes are collected or
    ``args.max_time`` has elapsed (the timer restarts on every call). For each
    success, the robot is set to the sampled ``bq``/``aq1`` configurations and the
    pose of ``world.base_link`` is stored relative to the joint's reference pose
    under the key ``'joint_from_base'``.

    :param world: the World; ``world.robot`` and ``world.kitchen`` are read and
        their configurations are modified by this function.
    :param joint_name: name of the kitchen joint or knob to sample for.
    :param args: parsed CLI namespace; reads ``cfree``, ``teleport``,
        ``num_samples`` and ``max_time``.
    :return: the dict written as JSON to ``get_pull_path(robot_name, joint_name)``,
        or ``None`` when no sample succeeded (any existing file at that path is
        then removed via ``safe_remove``).
    """
    date = get_date()
    # NOTE: args.seed is parsed by main() but is not applied here (set_seed is disabled).
    pressing = is_press(joint_name)  # evaluated once: selects press vs. pull mode
    robot_name = get_body_name(world.robot)
    sample, open_conf, closed_conf = _get_sampler(world, joint_name, args)

    path = get_pull_path(robot_name, joint_name)
    print(SEPARATOR)
    print('Robot name {} | Joint name: {} | Filename: {}'.format(robot_name, joint_name, path))

    entries = []
    failures = 0
    start_time = time.time()
    while (len(entries) < args.num_samples) and \
            (elapsed_time(start_time) < args.max_time):
        result = sample()
        if result is None:
            print('Failure! | {} / {} [{:.3f}]'.format(
                len(entries), args.num_samples, elapsed_time(start_time)))
            failures += 1
            continue
        if not pressing:
            # Put the kitchen joint in its open configuration before reading its reference pose.
            open_conf.assign()
        joint_pose = get_joint_reference_pose(world.kitchen, joint_name)
        bq, aq1 = result[:2]
        bq.assign()
        aq1.assign()
        base_pose = get_link_pose(world.robot, world.base_link)
        entries.append({
            'joint_from_base': multiply(invert(joint_pose), base_pose),
        })
        print('Success! | {} / {} [{:.3f}]'.format(
            len(entries), args.num_samples, elapsed_time(start_time)))
        if has_gui():
            # Pause after each success so the sample can be inspected in the viewer.
            wait_for_user()

    if not entries:
        # Do not leave a stale database from an earlier run for this joint.
        safe_remove(path)
        return None

    # Assuming the kitchen is fixed but the objects might be open world
    # TODO: could store per data point
    data = {
        'date': date,
        'robot_name': robot_name, # get_name | get_body_name | get_base_name | world.robot_name
        'base_link': get_link_name(world.robot, world.base_link),
        'tool_link': get_link_name(world.robot, world.tool_link),
        'kitchen_name': get_body_name(world.kitchen),
        'joint_name': joint_name,
        'entries': entries,
        'failures': failures,
        'successes': len(entries),
    }
    if not pressing:
        data.update({
            'open_conf': open_conf.values,
            'closed_conf': closed_conf.values,
        })
    write_json(path, data)
    print('Saved', path)
    return data
################################################################################
def main():
    """Parse CLI options, build the World, and run ``collect_pull`` for each target.

    Samples are collected first for every joint in ``ZED_LEFT_JOINTS`` and then
    for every entry in ``KNOBS``. After printing both lists, the script waits for
    the user to confirm before starting. The World is always destroyed on exit,
    including when sampling raises or is interrupted.
    """
    parser = argparse.ArgumentParser()
    parser.add_argument('-cfree', action='store_true',
                        help='When enabled, disables collision checking (for debugging).')
    parser.add_argument('-max_time', default=10 * 60, type=float,
                        help='The maximum runtime')
    parser.add_argument('-num_samples', default=1000, type=int,
                        help='The number of samples')
    parser.add_argument('-seed', default=None,
                        help='The random seed to use.')
    parser.add_argument('-teleport', action='store_true',
                        help='Teleports between configurations')
    parser.add_argument('-visualize', action='store_true',
                        help='When enabled, visualizes planning rather than the world (for debugging).')
    args = parser.parse_args()
    # TODO: could record the full trajectories here

    world = World(use_gui=args.visualize)
    try:
        world.open_gripper()
        # Alternative target set: DRAWER_JOINTS + CABINET_JOINTS
        joint_names = ZED_LEFT_JOINTS
        print('Joints:', joint_names)
        print('Knobs:', KNOBS)
        wait_for_user('Start?')
        for joint_name in joint_names:
            collect_pull(world, joint_name, args)
        for knob_name in KNOBS:
            collect_pull(world, knob_name, args)
    finally:
        # Release the simulation even if a collection run fails or is interrupted (Ctrl-C).
        world.destroy()


if __name__ == '__main__':
    main()