Skip to content

Commit

Permalink
initial test working
Browse files Browse the repository at this point in the history
  • Loading branch information
ashay-bdai committed Sep 11, 2024
1 parent c832fb0 commit ea5f1bf
Show file tree
Hide file tree
Showing 8 changed files with 46 additions and 10 deletions.
3 changes: 2 additions & 1 deletion predicators/approaches/bilevel_planning_approach.py
Original file line number Diff line number Diff line change
Expand Up @@ -58,13 +58,14 @@ def _solve(self, task: Task, timeout: int) -> Callable[[State], Action]:
seed = self._seed + self._num_calls
nsrts = self._get_current_nsrts()
preds = self._get_current_predicates()
import pdb; pdb.set_trace()
# utils.abstract(task.init, preds, self._vlm)

# Run task planning only and then greedily sample and execute in the
# policy.
if self._plan_without_sim:
nsrt_plan, atoms_seq, metrics = self._run_task_plan(
task, nsrts, preds, timeout, seed)
# import pdb; pdb.set_trace()
self._last_nsrt_plan = nsrt_plan
self._last_atoms_seq = atoms_seq
policy = utils.nsrt_plan_to_greedy_policy(nsrt_plan, task.goal,
Expand Down
3 changes: 2 additions & 1 deletion predicators/approaches/spot_wrapper_approach.py
Original file line number Diff line number Diff line change
Expand Up @@ -102,7 +102,8 @@ def _policy(state: State) -> Action:
self._base_approach_has_control = True
# Need to call this once here to fix off-by-one issue.
atom_seq = self._base_approach.get_execution_monitoring_info()
assert all(a.holds(state) for a in atom_seq[0])
# TODO: consider reinstating the line below.
# assert all(a.holds(state) for a in atom_seq[0])
# Use the base policy.
return base_approach_policy(state)

Expand Down
30 changes: 30 additions & 0 deletions predicators/envs/spot_env.py
Original file line number Diff line number Diff line change
Expand Up @@ -2564,8 +2564,38 @@ def reset(self, train_or_test: str, task_idx: int) -> Observation:

def step(self, action: Action) -> Observation:
assert self._robot is not None
action_name = action.extra_info.action_name
# Special case: the action is "done", indicating that the robot
# believes it has finished the task. Used for goal checking.
if action_name == "done":
while True:
goal_description = self._current_task.goal_description
logging.info(f"The goal is: {goal_description}")
prompt = "Is the goal accomplished? Answer y or n. "
response = utils.prompt_user(prompt).strip()
if response == "y":
self._current_task_goal_reached = True
break
if response == "n":
self._current_task_goal_reached = False
break
logging.info("Invalid input, must be either 'y' or 'n'")
return self._current_observation

# Execute the action in the real environment. Automatically retry
# if a retryable error is encountered.
action_fn = action.extra_info.real_world_fn
action_fn_args = action.extra_info.real_world_fn_args
while True:
try:
action_fn(*action_fn_args) # type: ignore
break
except RetryableRpcError as e:
logging.warning("WARNING: the following retryable error "
f"was encountered. Trying again.\n{e}")
rgbd_images = capture_images_without_context(self._robot)
gripper_open_percentage = get_robot_gripper_open_percentage(self._robot)
print(gripper_open_percentage)
objects_in_view = []
obs = _TruncatedSpotObservation(
rgbd_images,
Expand Down
2 changes: 1 addition & 1 deletion predicators/ground_truth_models/spot_env/options.py
Original file line number Diff line number Diff line change
Expand Up @@ -901,7 +901,7 @@ def _move_to_ready_sweep_policy(state: State, memory: Dict,

def _teleop_policy(state: State, memory: Dict, objects: Sequence[Object], params: Array) -> Action:
del state, memory, params

robot, lease_client = get_robot_only()

def _teleop(robot: Robot, lease_client: LeaseClient):
Expand Down
5 changes: 4 additions & 1 deletion predicators/perception/spot_perceiver.py
Original file line number Diff line number Diff line change
Expand Up @@ -673,9 +673,12 @@ def step(self, observation: Observation) -> State:
self._robot = observation.robot
imgs = observation.rgbd_images
imgs = [v.rgb for _, v in imgs.items()]
# import PIL
# PIL.Image.fromarray(imgs[0]).show()
# import pdb; pdb.set_trace()
self._gripper_open_percentage = observation.gripper_open_percentage
self._curr_state = self._create_state()
self._curr_state.simulator_state["images"] = imgs
self._gripper_open_percentage = observation.gripper_open_percentage
ret_state = self._curr_state.copy()
return ret_state

Expand Down
2 changes: 1 addition & 1 deletion predicators/settings.py
Original file line number Diff line number Diff line change
Expand Up @@ -434,7 +434,7 @@ class GlobalSettings:

# parameters for vision language models
# gemini-1.5-pro-latest, gpt-4-turbo, gpt-4o
vlm_model_name = "gemini-1.5-pro-latest"
vlm_model_name = "gemini-1.5-flash" #"gemini-1.5-pro-latest"
vlm_temperature = 0.0
vlm_num_completions = 1
vlm_include_cropped_images = False
Expand Down
10 changes: 5 additions & 5 deletions predicators/spot_utils/perception/spot_cameras.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,12 +22,12 @@
'right_fisheye_image': 180
}
RGB_TO_DEPTH_CAMERAS = {
"hand_color_image": "hand_depth_in_hand_color_frame",
"left_fisheye_image": "left_depth_in_visual_frame",
"right_fisheye_image": "right_depth_in_visual_frame",
# "hand_color_image": "hand_depth_in_hand_color_frame",
# "left_fisheye_image": "left_depth_in_visual_frame",
# "right_fisheye_image": "right_depth_in_visual_frame",
"frontleft_fisheye_image": "frontleft_depth_in_visual_frame",
"frontright_fisheye_image": "frontright_depth_in_visual_frame",
"back_fisheye_image": "back_depth_in_visual_frame"
# "frontright_fisheye_image": "frontright_depth_in_visual_frame",
# "back_fisheye_image": "back_depth_in_visual_frame"
}

# Hack to avoid double image capturing when we want to (1) get object states
Expand Down
1 change: 1 addition & 0 deletions predicators/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -2527,6 +2527,7 @@ def query_vlm_for_atom_vals(
num_completions=1)
assert len(vlm_output) == 1
vlm_output_str = vlm_output[0]
print(f"VLM output: {vlm_output_str}")
all_atom_queries = atom_queries_str.strip().split("\n")
all_vlm_responses = vlm_output_str.strip().split("\n")

Expand Down

0 comments on commit ea5f1bf

Please sign in to comment.