From 08a576b2576739f2a7e081dc245215dbb0266c9e Mon Sep 17 00:00:00 2001 From: Paul Tunison Date: Thu, 17 Oct 2024 18:11:07 -0400 Subject: [PATCH] Update pose estimation node to use common structure Updated TCN HPL repo to get at update. --- python-tpl/TCN_HPL | 2 +- .../pose_estimation/pose_estimator.py | 144 +++++++----------- 2 files changed, 56 insertions(+), 90 deletions(-) diff --git a/python-tpl/TCN_HPL b/python-tpl/TCN_HPL index 4c3effab8..e98c333fd 160000 --- a/python-tpl/TCN_HPL +++ b/python-tpl/TCN_HPL @@ -1 +1 @@ -Subproject commit 4c3effab83cde089028a12f43cdd7d9a4557b2b0 +Subproject commit e98c333fdba6f709e9c477e4ac284c0fb906e206 diff --git a/ros/angel_system_nodes/angel_system_nodes/pose_estimation/pose_estimator.py b/ros/angel_system_nodes/angel_system_nodes/pose_estimation/pose_estimator.py index b4d239f4f..3fae25a0a 100644 --- a/ros/angel_system_nodes/angel_system_nodes/pose_estimation/pose_estimator.py +++ b/ros/angel_system_nodes/angel_system_nodes/pose_estimation/pose_estimator.py @@ -1,16 +1,26 @@ -from pathlib import Path from threading import Event, Lock, Thread -from typing import Union -import torch + +# !!! FOR SOME REASON CURRENTLY UNKNOWN !!! +# An import of MMPose must happen first before any detectron2 imports. +# Otherwise, something about the ROS2 node handle creation fails with a +# double-free (no idea why...). +# No this import is not being used in this file. Yes it has to be here. +# Maybe there is something more specific in this import chain that is what is +# really necessary, but the investigation is at diminishing returns... +import mmpose.apis + +from angel_system.utils.event import WaitAndClearEvent +import cv2 from cv_bridge import CvBridge +from geometry_msgs.msg import Point, Pose, Quaternion import numpy as np from rclpy.callback_groups import MutuallyExclusiveCallbackGroup, ReentrantCallbackGroup from rclpy.node import Node, ParameterDescriptor, Parameter from sensor_msgs.msg import Image - -import cv2 -from angel_system.utils.event import WaitAndClearEvent -from angel_system.utils.simple_timer import SimpleTimer +from tcn_hpl.data.utils.pose_generation.generate_pose_data import ( + DETECTION_CLASSES, + PosesGenerator, +) from angel_msgs.msg import ( ObjectDetection2dSet, @@ -20,34 +30,11 @@ ) from angel_utils import declare_and_get_parameters, RateTracker # , DYNAMIC_TYPE from angel_utils import make_default_main -from geometry_msgs.msg import Point, Pose, Quaternion - -from tcn_hpl.data.utils.pose_generation.rt_pose_generation import predict_single -from mmpose.apis import init_pose_model - -from detectron2.config import get_cfg -from tcn_hpl.data.utils.pose_generation.predictor import VisualizationDemo BRIDGE = CvBridge() -def setup_detectron_cfg(config_file, model_checkpoint): - # load config from file and command-line arguments - cfg = get_cfg() - # To use demo for Panoptic-DeepLab, please uncomment the following two lines. - # from detectron2.projects.panoptic_deeplab import add_panoptic_deeplab_config # noqa - # add_panoptic_deeplab_config(cfg) - cfg.merge_from_file(config_file) - # Set score_threshold for builtin models - cfg.MODEL.RETINANET.SCORE_THRESH_TEST = 0.8 - cfg.MODEL.ROI_HEADS.SCORE_THRESH_TEST = 0.8 - cfg.MODEL.PANOPTIC_FPN.COMBINE.INSTANCES_CONFIDENCE_THRESH = 0.8 - cfg.MODEL.WEIGHTS = model_checkpoint - cfg.freeze() - return cfg - - class PoseEstimator(Node): """ ROS node that runs the pose estimation model and outputs @@ -76,12 +63,8 @@ def __init__(self): ("pose_config",), ################################## # Defaulted parameters - ("inference_img_size", 1280), # inference size (pixels) - ("det_conf_threshold", 0.25), # object confidence threshold - ("iou_threshold", 0.25), # IOU threshold for NMS + ("det_conf_threshold", 0.75), # object confidence threshold ("cuda_device_id", 0), # cuda device: ID int or CPU - ("no_trace", True), # don`t trace model - ("agnostic_nms", False), # class-agnostic NMS # Runtime thread checkin heartbeat interval in seconds. ("rt_thread_heartbeat", 0.1), # If we should enable additional logging to the info level @@ -100,56 +83,30 @@ def __init__(self): self._ensure_image_resize = param_values["image_resize"] - self._inference_img_size = param_values["inference_img_size"] self._det_conf_thresh = param_values["det_conf_threshold"] - self._iou_thr = param_values["iou_threshold"] self._cuda_device_id = param_values["cuda_device_id"] - self._no_trace = param_values["no_trace"] - self._agnostic_nms = param_values["agnostic_nms"] - - self.keypoints_cats = [ - "nose", - "mouth", - "throat", - "chest", - "stomach", - "left_upper_arm", - "right_upper_arm", - "left_lower_arm", - "right_lower_arm", - "left_wrist", - "right_wrist", - "left_hand", - "right_hand", - "left_upper_leg", - "right_upper_leg", - "left_knee", - "right_knee", - "left_lower_leg", - "right_lower_leg", - "left_foot", - "right_foot", - "back", - ] print("finished setting params") - print("loading detectron model") - # Detectron Model - print(f"model_checkpoint: {self.det_model_ckpt_fp}") - detecron_cfg = setup_detectron_cfg( - self.det_config, model_checkpoint=self.det_model_ckpt_fp - ) - - self.det_model = VisualizationDemo(detecron_cfg) - - print("loading pose model") - # Pose model - self.pose_model = init_pose_model( - self.pose_config, self.pose_model_ckpt_fp, device=self._cuda_device_id + print("Initializing pose models...") + # Encapsulates detection and pose models. + self.pose_gen = PosesGenerator( + self.det_config, + self.pose_config, + self._det_conf_thresh, + self.det_model_ckpt_fp, + self._cuda_device_id, + self.pose_model_ckpt_fp, + self._cuda_device_id, ) - - self.device = torch.device(f"cuda:{self._cuda_device_id}") + print("Initializing pose models... Done") + self.keypoints_cats = [ + v + for _, v in sorted(self.pose_gen.pose_dataset_info.keypoint_id2name.items()) + ] + # Pose estimates considered by this node is constrained to that of the + # patient class specifically. + self.patient_class_idx: int = DETECTION_CLASSES.index("patient") self._enable_trace_logging = param_values["enable_time_trace_logging"] @@ -254,24 +211,33 @@ def rt_loop(self): # print(f"img0: {img0.shape}") # height, width, chans = img0.shape + boxes, scores, classes, keypoints = self.pose_gen.predict_single(img0) + + # Select keypoints for most confidence box among those with + # keypoints. + patient_results = np.argwhere(classes == self.patient_class_idx) + if patient_results.size > 0: + p_scores = scores[patient_results] + best_idx = patient_results[np.argmax(p_scores)][0] + assert ( + keypoints[best_idx] is not None + ), "Patient class should have keypoints but None were found" + keypoints = [keypoints[best_idx]] + else: + keypoints = [] + all_poses_msg = HandJointPosesUpdate() # note: setting metdata right before publishing below - boxes, labels, keypoints = predict_single( - det_model=self.det_model, - pose_model=self.pose_model, - image=img0, - bbox_thr=self._det_conf_thresh, - ) - - # at most, we have 1 set of keypoints for 1 patient - keypoints = keypoints[:1] + # at most, we have 1 set of keypoints for 1 "best" patient for keypoints_ in keypoints: for label, keypoint in zip(self.keypoints_cats, keypoints_): position = Point() position.x = float(keypoint[0]) position.y = float(keypoint[1]) - position.z = float(keypoint[2]) + position.z = float( + keypoint[2] + ) # This is actually the confidence score # Extract the orientation orientation = Quaternion()