Get camera annotation + object detection annotation to work on all images.
ashay-bdai committed Sep 14, 2024
1 parent cf40334 commit 94dff09
Showing 4 changed files with 106 additions and 102 deletions.
158 changes: 67 additions & 91 deletions predicators/envs/spot_env.py
@@ -105,6 +105,8 @@ class _TruncatedSpotObservation:
# # A placeholder until all predicates have classifiers
# nonpercept_atoms: Set[GroundAtom]
# nonpercept_predicates: Set[Predicate]
# Object detections per camera in self.rgbd_images.
object_detections_per_camera: Dict[str, List[Tuple[ObjectDetectionID, SegmentedBoundingBox]]]


class _PartialPerceptionState(State):
@@ -2472,7 +2474,9 @@ def _detection_id_to_obj(self) -> Dict[ObjectDetectionID, Object]:
detection_id_to_obj: Dict[ObjectDetectionID, Object] = {}
objects = {
Object("pan", _movable_object_type),
Object("cup", _movable_object_type)
Object("cup", _movable_object_type),
Object("chair", _movable_object_type),
Object("bowl", _movable_object_type),
}
for o in objects:
detection_id = LanguageObjectDetectionID(o.name)
@@ -2558,6 +2562,18 @@ def __init__(self, use_gui: bool = True) -> None:
self._train_tasks = []
self._test_tasks = []

def detect_objects(self, rgbd_images: Dict[str, RGBDImage]) -> Dict[str, List[Tuple[ObjectDetectionID, SegmentedBoundingBox]]]:
object_ids = self._detection_id_to_obj.keys()
object_id_to_img_detections = _query_detic_sam2(object_ids, rgbd_images)
# This ^ is currently a mapping of object_id -> camera_name -> SegmentedBoundingBox.
# We want to do our annotations per camera image, so invert it into a
# mapping of camera_name -> list of (object_id, SegmentedBoundingBox) pairs.
detections = {k: [] for k in rgbd_images.keys()}
for object_id, d in object_id_to_img_detections.items():
for camera_name, seg_bb in d.items():
detections[camera_name].append((object_id, seg_bb))
return detections
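
The inversion above is easiest to see with toy data. A minimal self-contained sketch — plain strings stand in for ObjectDetectionID and SegmentedBoundingBox, and the camera/object names are illustrative, not from this commit:

# Toy version of the inversion performed by detect_objects.
object_id_to_img_detections = {
    "cup": {"frontleft_fisheye_image": "cup_bb"},
    "bowl": {"frontleft_fisheye_image": "bowl_bb", "back_fisheye_image": "bowl_bb2"},
}
cameras = ["frontleft_fisheye_image", "back_fisheye_image"]
# One empty list per camera, then flatten the nested mapping into it.
detections = {camera: [] for camera in cameras}
for object_id, per_camera in object_id_to_img_detections.items():
    for camera_name, seg_bb in per_camera.items():
        detections[camera_name].append((object_id, seg_bb))
assert detections["back_fisheye_image"] == [("bowl", "bowl_bb2")]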

def _actively_construct_env_task(self) -> EnvironmentTask:
assert self._robot is not None
rgbd_images = capture_images_without_context(self._robot)
@@ -2571,103 +2587,60 @@ def _actively_construct_env_task(self) -> EnvironmentTask:
objects_in_view = []

# Perform object detection.
object_ids = self._detection_id_to_obj.keys()
ret = _query_detic_sam2(object_ids, rgbd_images)
artifacts = {"language": {"rgbds": rgbd_images, "object_id_to_img_detections": ret}}
detections_outfile = Path(".") / "object_detection_artifacts.png"
no_detections_outfile = Path(".") / "no_detection_artifacts.png"
visualize_all_artifacts(artifacts, detections_outfile, no_detections_outfile)

# Draw object bounding box on images.
rgbds = artifacts["language"]["rgbds"]
detections = artifacts["language"]["object_id_to_img_detections"]
flat_detections: List[Tuple[RGBDImage,
LanguageObjectDetectionID,
SegmentedBoundingBox]] = []
for obj_id, img_detections in detections.items():
for camera, seg_bb in img_detections.items():
rgbd = rgbds[camera]
flat_detections.append((rgbd, obj_id, seg_bb))
object_detections_per_camera = self.detect_objects(rgbd_images)


# artifacts = {"language": {"rgbds": rgbd_images, "object_id_to_img_detections": ret}}
# detections_outfile = Path(".") / "object_detection_artifacts.png"
# no_detections_outfile = Path(".") / "no_detection_artifacts.png"
# visualize_all_artifacts(artifacts, detections_outfile, no_detections_outfile)

# # Draw object bounding box on images.
# rgbds = artifacts["language"]["rgbds"]
# detections = artifacts["language"]["object_id_to_img_detections"]
# flat_detections: List[Tuple[RGBDImage,
# LanguageObjectDetectionID,
# SegmentedBoundingBox]] = []
# for obj_id, img_detections in detections.items():
# for camera, seg_bb in img_detections.items():
# rgbd = rgbds[camera]
# flat_detections.append((rgbd, obj_id, seg_bb))

# For now assume we only have 1 image, front-left.
import pdb; pdb.set_trace()
import PIL
from PIL import ImageDraw, ImageFont
bb_pil_imgs = []
img = list(rgbd_images.values())[0].rotated_rgb
pil_img = PIL.Image.fromarray(img)
draw = ImageDraw.Draw(pil_img)
for i, (rgbd, obj_id, seg_bb) in enumerate(flat_detections):
# img = rgbd.rotated_rgb
# pil_img = PIL.Image.fromarray(img)
x0, y0, x1, y1 = seg_bb.bounding_box
draw.rectangle([(x0, y0), (x1, y1)], outline='green', width=2)
text = f"{obj_id.language_id}"
font = ImageFont.load_default()
# font = utils.get_scaled_default_font(draw, 4)
# text_width, text_height = draw.textsize(text, font)
# text_width = draw.textlength(text, font)
# text_height = font.getsize("hg")[1]
text_mask = font.getmask(text)
text_width, text_height = text_mask.size
text_bbox = [(x0, y0 - text_height - 2), (x0 + text_width + 2, y0)]
draw.rectangle(text_bbox, fill='green')
draw.text((x0 + 1, y0 - text_height - 1), text, fill='white', font=font)

import pdb; pdb.set_trace()



# box = seg_bb.bounding_box
# x0, y0 = box[0], box[1]
# w, h = box[2] - box[0], box[3] - box[1]
# ax_row[3].add_patch(
# plt.Rectangle((x0, y0),
# w,
# h,
# edgecolor='green',
# facecolor=(0, 0, 0, 0),
# lw=1))

# import PIL
# from PIL import ImageDraw
# annotated_pil_imgs = []
# for img, img_name in zip(imgs, img_names):
# pil_img = PIL.Image.fromarray(img)
# draw = ImageDraw.Draw(pil_img)
# font = utils.get_scaled_default_font(draw, 4)
# annotated_pil_img = utils.add_text_to_draw_img(draw, (0, 0), self.camera_name_to_annotation[img_name], font)
# annotated_pil_imgs.append(pil_img)
# annotated_imgs = [np.array(img) for img in annotated_pil_imgs]

# im = Image.open(image_path)
# draw = ImageDraw.Draw(im)
# font = ImageFont.load_default() # You can use a specific font if needed

# for mask in masks:
# # Assuming you have a function to convert the mask to a PIL Image or polygon
# mask_image = convert_mask_to_pil(mask)
# im.paste(mask_image, (0, 0), mask_image)

# for box, class_name, score in zip(input_boxes, classes, scores):
# x0, y0, x1, y1 = box
# draw.rectangle([(x0, y0), (x1, y1)], outline='green', width=2)
# text = f"{class_name}: {score:.2f}"
# text_width, text_height = draw.textsize(text, font)
# text_bbox = [(x0, y0 - text_height - 2), (x0 + text_width + 2, y0)]
# draw.rectangle(text_bbox, fill='green')
# draw.text((x0 + 1, y0 - text_height - 1), text, fill='white', font=font)

# im.show() # Or save it: im.save("output.jpg")
# import pdb; pdb.set_trace()
# # For now assume we only have 1 image, front-left.
# import pdb; pdb.set_trace()
# import PIL
# from PIL import ImageDraw, ImageFont
# bb_pil_imgs = []
# img = list(rgbd_images.values())[0].rotated_rgb
# pil_img = PIL.Image.fromarray(img)
# draw = ImageDraw.Draw(pil_img)
# for i, (rgbd, obj_id, seg_bb) in enumerate(flat_detections):
# # img = rgbd.rotated_rgb
# # pil_img = PIL.Image.fromarray(img)
# x0, y0, x1, y1 = seg_bb.bounding_box
# draw.rectangle([(x0, y0), (x1, y1)], outline='green', width=2)
# text = f"{obj_id.language_id}"
# font = ImageFont.load_default()
# # font = utils.get_scaled_default_font(draw, 4)
# # text_width, text_height = draw.textsize(text, font)
# # text_width = draw.textlength(text, font)
# # text_height = font.getsize("hg")[1]
# text_mask = font.getmask(text)
# text_width, text_height = text_mask.size
# text_bbox = [(x0, y0 - text_height - 2), (x0 + text_width + 2, y0)]
# draw.rectangle(text_bbox, fill='green')
# draw.text((x0 + 1, y0 - text_height - 1), text, fill='white', font=font)

# import pdb; pdb.set_trace()

obs = _TruncatedSpotObservation(
rgbd_images,
set(objects_in_view),
set(),
set(),
self._spot_object,
gripper_open_percentage
gripper_open_percentage,
object_detections_per_camera
)
goal_description = self._generate_goal_description()
task = EnvironmentTask(obs, goal_description)
@@ -2728,14 +2701,17 @@ def step(self, action: Action) -> Observation:
rgbd_images = capture_images_without_context(self._robot)
gripper_open_percentage = get_robot_gripper_open_percentage(self._robot)
objects_in_view = []
# Perform object detection.
object_detections_per_camera = self.detect_objects(rgbd_images)

obs = _TruncatedSpotObservation(
rgbd_images,
set(objects_in_view),
set(),
set(),
self._spot_object,
gripper_open_percentage
gripper_open_percentage,
object_detections_per_camera
)
return obs

36 changes: 32 additions & 4 deletions predicators/perception/spot_perceiver.py
@@ -683,12 +683,40 @@ def reset(self, env_task: EnvironmentTask) -> Task:
return Task(state, goal)

def step(self, observation: Observation) -> State:
import pdb; pdb.set_trace()
self._waiting_for_observation = False
self._robot = observation.robot
imgs = observation.rgbd_images
img_names = [v.camera_name for _, v in imgs.items()]
imgs = [v.rgb for _, v in imgs.items()]
# import pdb; pdb.set_trace()
img_objects = observation.rgbd_images # RGBDImage objects
img_names = [v.camera_name for _, v in img_objects.items()]
imgs = [v.rotated_rgb for _, v in img_objects.items()]
import PIL
from PIL import ImageDraw, ImageFont
pil_imgs = [PIL.Image.fromarray(img) for img in imgs]
# Annotate images with detected objects (names + bounding box)
# and camera name.
object_detections_per_camera = observation.object_detections_per_camera
imgs_with_objects_annotated = [] # These are PIL images.
for i, camera_name in enumerate(img_names):
draw = ImageDraw.Draw(pil_imgs[i])
# Annotate with camera name.
font = utils.get_scaled_default_font(draw, 4)
_ = utils.add_text_to_draw_img(draw, (0, 0), self.camera_name_to_annotation[camera_name], font)
# Annotate with object detections.
detections = object_detections_per_camera[camera_name]
for obj_id, seg_bb in detections:
x0, y0, x1, y1 = seg_bb.bounding_box
draw.rectangle([(x0, y0), (x1, y1)], outline='green', width=2)
text = f"{obj_id.language_id}"
font = ImageFont.load_default()
text_mask = font.getmask(text)
text_width, text_height = text_mask.size
text_bbox = [(x0, y0 - text_height - 2), (x0 + text_width + 2, y0)]
draw.rectangle(text_bbox, fill='green')
draw.text((x0 + 1, y0 - text_height - 1), text, fill='white', font=font)
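
The loop measures each label with font.getmask(text).size, which works even on Pillow versions where ImageDraw.textsize has been removed. On Pillow >= 8, draw.textbbox gives the same measurement; a self-contained alternative sketch, not part of this commit (the box corner and label are assumed values):

from PIL import Image, ImageDraw, ImageFont

img = Image.new("RGB", (640, 480))
draw = ImageDraw.Draw(img)
font = ImageFont.load_default()
text = "cup"  # illustrative label
# textbbox returns (left, top, right, bottom) for the rendered text.
left, top, right, bottom = draw.textbbox((0, 0), text, font=font)
text_width, text_height = right - left, bottom - top
x0, y0 = 100, 120  # assumed top-left corner of a detection box
draw.rectangle([(x0, y0 - text_height - 2), (x0 + text_width + 2, y0)], fill="green")
draw.text((x0 + 1, y0 - text_height - 1), text, fill="white", font=font)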

# import pdb; pdb.set_trace()


import PIL
from PIL import ImageDraw
annotated_pil_imgs = []
2 changes: 1 addition & 1 deletion predicators/spot_utils/perception/object_detection.py
@@ -450,7 +450,7 @@ def _query_detic_sam2(
scores[best_idx])
object_id_to_img_detections[obj_id][rgbd.camera_name] = seg_bb

import pdb; pdb.set_trace()
# import pdb; pdb.set_trace()
return object_id_to_img_detections


12 changes: 6 additions & 6 deletions predicators/spot_utils/perception/spot_cameras.py
@@ -23,12 +23,12 @@
'right_fisheye_image': 180
}
RGB_TO_DEPTH_CAMERAS = {
# "hand_color_image": "hand_depth_in_hand_color_frame",
# "left_fisheye_image": "left_depth_in_visual_frame",
# "right_fisheye_image": "right_depth_in_visual_frame",
"hand_color_image": "hand_depth_in_hand_color_frame",
"left_fisheye_image": "left_depth_in_visual_frame",
"right_fisheye_image": "right_depth_in_visual_frame",
"frontleft_fisheye_image": "frontleft_depth_in_visual_frame",
# "frontright_fisheye_image": "frontright_depth_in_visual_frame",
# "back_fisheye_image": "back_depth_in_visual_frame"
"frontright_fisheye_image": "frontright_depth_in_visual_frame",
"back_fisheye_image": "back_depth_in_visual_frame"
}
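
With every RGB camera re-enabled, the mapping pairs each RGB source with its registered depth source. A hedged sketch of how such pairs are typically turned into Spot image requests — this uses the standard bosdyn-client API and is not part of this diff:

from bosdyn.client.image import build_image_request

# One request per RGB source and one per matching depth source.
requests = []
for rgb_source, depth_source in RGB_TO_DEPTH_CAMERAS.items():
    requests.append(build_image_request(rgb_source, quality_percent=100))
    requests.append(build_image_request(depth_source))
# image_client.get_image(requests) would then return paired RGB and depth responses.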

# Hack to avoid double image capturing when we want to (1) get object states
@@ -125,7 +125,7 @@ def capture_images_without_context(
robot: Robot,
camera_names: Optional[Collection[str]] = None,
quality_percent: int = 100,
) -> Dict[str, RGBDImageWithContext]:
) -> Dict[str, RGBDImage]:
"""Build an image request and get the responses.
If no camera names are provided, all RGB cameras are used.
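
The corrected return annotation matters downstream: callers now get plain RGBDImage objects rather than RGBDImageWithContext. A hedged usage sketch, assuming a connected robot handle; the attribute names are taken from their uses elsewhere in this commit:

# rgbds maps camera_name -> RGBDImage.
rgbds = capture_images_without_context(robot)
for camera_name, rgbd in rgbds.items():
    # rotated_rgb is the display-oriented array the perceiver annotates.
    print(camera_name, rgbd.rotated_rgb.shape)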
