Skip to content

Commit

Permalink
first commit
Browse files Browse the repository at this point in the history
  • Loading branch information
andipeng committed May 16, 2024
1 parent a8151a2 commit 1f7f524
Show file tree
Hide file tree
Showing 6 changed files with 84 additions and 14 deletions.
53 changes: 53 additions & 0 deletions predicators/envs/spot_env.py
Original file line number Diff line number Diff line change
Expand Up @@ -3079,3 +3079,56 @@ def _generate_goal_description(self) -> GoalDescription:
def _get_dry_task(self, train_or_test: str,
task_idx: int) -> EnvironmentTask:
raise NotImplementedError("Dry task generation not implemented.")


###############################################################################
# Test plant demo #
###############################################################################


class TestPlantEnv(SpotRearrangementEnv):
"""TODO; basic demo
"""

def __init__(self, use_gui: bool = True) -> None:
super().__init__(use_gui)

op_to_name = {o.name: o for o in _create_operators()}
op_names_to_keep = {
"MoveToReachObject",
"MoveToHandViewObject",
"PickObjectFromTop",
"PlaceObjectOnTop",
}
self._strips_operators = {op_to_name[o] for o in op_names_to_keep}

@classmethod
def get_name(cls) -> str:
return "plant_test_env"

@property
def _detection_id_to_obj(self) -> Dict[ObjectDetectionID, Object]:

detection_id_to_obj: Dict[ObjectDetectionID, Object] = {}

green_apple = Object("green_apple", _movable_object_type)
green_apple_detection = LanguageObjectDetectionID(
"green apple/tennis ball")
detection_id_to_obj[green_apple_detection] = green_apple
plant = Object("plant", _immovable_object_type)
plant_detection = LanguageObjectDetectionID(
"potted plant")
detection_id_to_obj[plant_detection] = plant

for obj, pose in get_known_immovable_objects().items():
detection_id = KnownStaticObjectDetectionID(obj.name, pose)
detection_id_to_obj[detection_id] = obj

return detection_id_to_obj

def _generate_goal_description(self) -> GoalDescription:
return "place the green apple on the plant"

def _get_dry_task(self, train_or_test: str,
task_idx: int) -> EnvironmentTask:
raise NotImplementedError("Dry task generation not implemented.")
2 changes: 1 addition & 1 deletion predicators/ground_truth_models/spot_env/nsrts.py
Original file line number Diff line number Diff line change
Expand Up @@ -288,7 +288,7 @@ def get_env_names(cls) -> Set[str]:
"spot_cube_env", "spot_soda_floor_env", "spot_soda_table_env",
"spot_soda_bucket_env", "spot_soda_chair_env",
"spot_main_sweep_env", "spot_ball_and_cup_sticky_table_env",
"spot_brush_shelf_env", "lis_spot_block_floor_env"
"spot_brush_shelf_env", "lis_spot_block_floor_env", "plant_test_env"
}

@staticmethod
Expand Down
1 change: 1 addition & 0 deletions predicators/ground_truth_models/spot_env/options.py
Original file line number Diff line number Diff line change
Expand Up @@ -996,6 +996,7 @@ def get_env_names(cls) -> Set[str]:
"spot_ball_and_cup_sticky_table_env",
"spot_brush_shelf_env",
"lis_spot_block_floor_env",
"plant_test_env"
}

@classmethod
Expand Down
29 changes: 18 additions & 11 deletions predicators/perception/spot_perceiver.py
Original file line number Diff line number Diff line change
Expand Up @@ -269,17 +269,17 @@ def _create_state(self) -> State:
}

# Uncomment for debugging.
# logging.info("Percept state:")
# logging.info(percept_state.pretty_str())
# logging.info("Percept atoms:")
# atom_str = "\n".join(
# map(
# str,
# sorted(utils.abstract(percept_state,
# self._percept_predicates))))
# logging.info(atom_str)
# logging.info("Simulator state:")
# logging.info(simulator_state)
logging.info("Percept state:")
logging.info(percept_state.pretty_str())
logging.info("Percept atoms:")
atom_str = "\n".join(
map(
str,
sorted(utils.abstract(percept_state,
self._percept_predicates))))
logging.info(atom_str)
logging.info("Simulator state:")
logging.info(simulator_state)

# Now finish the state.
state = _PartialPerceptionState(percept_state.data,
Expand Down Expand Up @@ -500,6 +500,13 @@ def _create_goal(self, state: State,
GroundAtom(ContainerReadyForSweeping, [bucket, black_table]),
GroundAtom(IsSweeper, [brush])
}
if goal_description == "place the green apple on the plant":
plant = Object("plant", _immovable_object_type)
apple = Object("green_apple", _movable_object_type)
On = pred_name_to_pred["On"]
return {
GroundAtom(On, [apple, plant]),
}
raise NotImplementedError("Unrecognized goal description")

def render_mental_images(self, observation: Observation,
Expand Down
12 changes: 10 additions & 2 deletions predicators/spot_utils/graph_nav_maps/b45-621/metadata.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -40,10 +40,18 @@ static-object-features:
length: 10000000 # effectively infinite
width: 10000000
flat_top_surface: 1
red_block:
green_apple:
shape: 2
height: 0.1
length: 0.1
width: 0.1
placeable: 1
is_sweeper: 0
is_sweeper: 0
plant:
shape: 2
height: 0.25
length: 0.25
width: 0.25
placeable: 1
is_sweeper: 0
flat_top_surface: 1
1 change: 1 addition & 0 deletions predicators/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -526,6 +526,7 @@ def sample_random_point(
self,
rng: np.random.Generator,
min_dist_from_edge: float = 0.0) -> Tuple[float, float]:
import ipdb; ipdb.set_trace()
assert min_dist_from_edge < self.radius, "min_dist_from_edge is " + \
"greater than radius"
rand_mag = rng.uniform(0, self.radius - min_dist_from_edge)
Expand Down

0 comments on commit 1f7f524

Please sign in to comment.