Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Update pose estimation node to use common structure #456

Merged
merged 1 commit into from
Oct 21, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
@@ -1,16 +1,26 @@
from pathlib import Path
from threading import Event, Lock, Thread
from typing import Union
import torch

# !!! FOR SOME REASON CURRENTLY UNKNOWN !!!
# An import of MMPose must happen first before any detectron2 imports.
# Otherwise, something about the ROS2 node handle creation fails with a
# double-free (no idea why...).
# No this import is not being used in this file. Yes it has to be here.
# Maybe there is something more specific in this import chain that is what is
# really necessary, but the investigation is at diminishing returns...
import mmpose.apis

from angel_system.utils.event import WaitAndClearEvent
import cv2
from cv_bridge import CvBridge
from geometry_msgs.msg import Point, Pose, Quaternion
import numpy as np
from rclpy.callback_groups import MutuallyExclusiveCallbackGroup, ReentrantCallbackGroup
from rclpy.node import Node, ParameterDescriptor, Parameter
from sensor_msgs.msg import Image

import cv2
from angel_system.utils.event import WaitAndClearEvent
from angel_system.utils.simple_timer import SimpleTimer
from tcn_hpl.data.utils.pose_generation.generate_pose_data import (
DETECTION_CLASSES,
PosesGenerator,
)

from angel_msgs.msg import (
ObjectDetection2dSet,
Expand All @@ -20,34 +30,11 @@
)
from angel_utils import declare_and_get_parameters, RateTracker # , DYNAMIC_TYPE
from angel_utils import make_default_main
from geometry_msgs.msg import Point, Pose, Quaternion

from tcn_hpl.data.utils.pose_generation.rt_pose_generation import predict_single
from mmpose.apis import init_pose_model

from detectron2.config import get_cfg
from tcn_hpl.data.utils.pose_generation.predictor import VisualizationDemo


BRIDGE = CvBridge()


def setup_detectron_cfg(config_file, model_checkpoint):
# load config from file and command-line arguments
cfg = get_cfg()
# To use demo for Panoptic-DeepLab, please uncomment the following two lines.
# from detectron2.projects.panoptic_deeplab import add_panoptic_deeplab_config # noqa
# add_panoptic_deeplab_config(cfg)
cfg.merge_from_file(config_file)
# Set score_threshold for builtin models
cfg.MODEL.RETINANET.SCORE_THRESH_TEST = 0.8
cfg.MODEL.ROI_HEADS.SCORE_THRESH_TEST = 0.8
cfg.MODEL.PANOPTIC_FPN.COMBINE.INSTANCES_CONFIDENCE_THRESH = 0.8
cfg.MODEL.WEIGHTS = model_checkpoint
cfg.freeze()
return cfg


class PoseEstimator(Node):
"""
ROS node that runs the pose estimation model and outputs
Expand Down Expand Up @@ -76,12 +63,8 @@ def __init__(self):
("pose_config",),
##################################
# Defaulted parameters
("inference_img_size", 1280), # inference size (pixels)
("det_conf_threshold", 0.25), # object confidence threshold
("iou_threshold", 0.25), # IOU threshold for NMS
("det_conf_threshold", 0.75), # object confidence threshold
("cuda_device_id", 0), # cuda device: ID int or CPU
("no_trace", True), # don`t trace model
("agnostic_nms", False), # class-agnostic NMS
# Runtime thread checkin heartbeat interval in seconds.
("rt_thread_heartbeat", 0.1),
# If we should enable additional logging to the info level
Expand All @@ -100,56 +83,30 @@ def __init__(self):

self._ensure_image_resize = param_values["image_resize"]

self._inference_img_size = param_values["inference_img_size"]
self._det_conf_thresh = param_values["det_conf_threshold"]
self._iou_thr = param_values["iou_threshold"]
self._cuda_device_id = param_values["cuda_device_id"]
self._no_trace = param_values["no_trace"]
self._agnostic_nms = param_values["agnostic_nms"]

self.keypoints_cats = [
"nose",
"mouth",
"throat",
"chest",
"stomach",
"left_upper_arm",
"right_upper_arm",
"left_lower_arm",
"right_lower_arm",
"left_wrist",
"right_wrist",
"left_hand",
"right_hand",
"left_upper_leg",
"right_upper_leg",
"left_knee",
"right_knee",
"left_lower_leg",
"right_lower_leg",
"left_foot",
"right_foot",
"back",
]

print("finished setting params")

print("loading detectron model")
# Detectron Model
print(f"model_checkpoint: {self.det_model_ckpt_fp}")
detecron_cfg = setup_detectron_cfg(
self.det_config, model_checkpoint=self.det_model_ckpt_fp
)

self.det_model = VisualizationDemo(detecron_cfg)

print("loading pose model")
# Pose model
self.pose_model = init_pose_model(
self.pose_config, self.pose_model_ckpt_fp, device=self._cuda_device_id
print("Initializing pose models...")
# Encapsulates detection and pose models.
self.pose_gen = PosesGenerator(
self.det_config,
self.pose_config,
self._det_conf_thresh,
self.det_model_ckpt_fp,
self._cuda_device_id,
self.pose_model_ckpt_fp,
self._cuda_device_id,
)

self.device = torch.device(f"cuda:{self._cuda_device_id}")
print("Initializing pose models... Done")
self.keypoints_cats = [
v
for _, v in sorted(self.pose_gen.pose_dataset_info.keypoint_id2name.items())
]
# Pose estimates considered by this node is constrained to that of the
# patient class specifically.
self.patient_class_idx: int = DETECTION_CLASSES.index("patient")

self._enable_trace_logging = param_values["enable_time_trace_logging"]

Expand Down Expand Up @@ -254,24 +211,33 @@ def rt_loop(self):
# print(f"img0: {img0.shape}")
# height, width, chans = img0.shape

boxes, scores, classes, keypoints = self.pose_gen.predict_single(img0)

# Select keypoints for most confidence box among those with
# keypoints.
patient_results = np.argwhere(classes == self.patient_class_idx)
if patient_results.size > 0:
p_scores = scores[patient_results]
best_idx = patient_results[np.argmax(p_scores)][0]
assert (
keypoints[best_idx] is not None
), "Patient class should have keypoints but None were found"
keypoints = [keypoints[best_idx]]
else:
keypoints = []

all_poses_msg = HandJointPosesUpdate()
# note: setting metdata right before publishing below

boxes, labels, keypoints = predict_single(
det_model=self.det_model,
pose_model=self.pose_model,
image=img0,
bbox_thr=self._det_conf_thresh,
)

# at most, we have 1 set of keypoints for 1 patient
keypoints = keypoints[:1]
# at most, we have 1 set of keypoints for 1 "best" patient
for keypoints_ in keypoints:
for label, keypoint in zip(self.keypoints_cats, keypoints_):
position = Point()
position.x = float(keypoint[0])
position.y = float(keypoint[1])
position.z = float(keypoint[2])
position.z = float(
keypoint[2]
) # This is actually the confidence score

# Extract the orientation
orientation = Quaternion()
Expand Down
Loading