Merge pull request #456 from Purg/dev/update-pose-node
Update pose estimation node to use common structure
Purg authored Oct 21, 2024
2 parents b579873 + 08a576b commit ac4b8f0
Showing 2 changed files with 56 additions and 90 deletions.
@@ -1,16 +1,26 @@
from pathlib import Path
from threading import Event, Lock, Thread
from typing import Union
-import torch

+# !!! FOR SOME REASON CURRENTLY UNKNOWN !!!
+# An import of MMPose must happen first before any detectron2 imports.
+# Otherwise, something about the ROS2 node handle creation fails with a
+# double-free (no idea why...).
+# No this import is not being used in this file. Yes it has to be here.
+# Maybe there is something more specific in this import chain that is what is
+# really necessary, but the investigation is at diminishing returns...
+import mmpose.apis

-from angel_system.utils.event import WaitAndClearEvent
+import cv2
+from cv_bridge import CvBridge
+from geometry_msgs.msg import Point, Pose, Quaternion
+import numpy as np
+from rclpy.callback_groups import MutuallyExclusiveCallbackGroup, ReentrantCallbackGroup
+from rclpy.node import Node, ParameterDescriptor, Parameter
+from sensor_msgs.msg import Image

-import cv2
from angel_system.utils.event import WaitAndClearEvent
from angel_system.utils.simple_timer import SimpleTimer
+from tcn_hpl.data.utils.pose_generation.generate_pose_data import (
+    DETECTION_CLASSES,
+    PosesGenerator,
+)

from angel_msgs.msg import (
    ObjectDetection2dSet,
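The comment block added above records a load-order workaround, not a functional dependency: MMPose must be imported before anything that pulls in detectron2, or ROS2 node-handle creation crashes with a double-free. A minimal sketch of the same pattern; the `noqa` tag is an assumption, not part of the diff:

```python
# Import order matters: MMPose first, then anything detectron2-related.
import mmpose.apis  # noqa: F401  # intentionally "unused"; see workaround note

from detectron2.config import get_cfg  # safe only after the mmpose import
```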
@@ -20,34 +30,11 @@
)
from angel_utils import declare_and_get_parameters, RateTracker # , DYNAMIC_TYPE
from angel_utils import make_default_main
-from geometry_msgs.msg import Point, Pose, Quaternion
-
-from tcn_hpl.data.utils.pose_generation.rt_pose_generation import predict_single
-from mmpose.apis import init_pose_model
-
-from detectron2.config import get_cfg
-from tcn_hpl.data.utils.pose_generation.predictor import VisualizationDemo


BRIDGE = CvBridge()


-def setup_detectron_cfg(config_file, model_checkpoint):
-    # load config from file and command-line arguments
-    cfg = get_cfg()
-    # To use demo for Panoptic-DeepLab, please uncomment the following two lines.
-    # from detectron2.projects.panoptic_deeplab import add_panoptic_deeplab_config  # noqa
-    # add_panoptic_deeplab_config(cfg)
-    cfg.merge_from_file(config_file)
-    # Set score_threshold for builtin models
-    cfg.MODEL.RETINANET.SCORE_THRESH_TEST = 0.8
-    cfg.MODEL.ROI_HEADS.SCORE_THRESH_TEST = 0.8
-    cfg.MODEL.PANOPTIC_FPN.COMBINE.INSTANCES_CONFIDENCE_THRESH = 0.8
-    cfg.MODEL.WEIGHTS = model_checkpoint
-    cfg.freeze()
-    return cfg
-
-
class PoseEstimator(Node):
    """
    ROS node that runs the pose estimation model and outputs
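The deleted `setup_detectron_cfg` helper followed detectron2's standard config workflow: start from library defaults, merge the model's YAML, raise the test-time score thresholds, point `MODEL.WEIGHTS` at a checkpoint, and freeze. A self-contained sketch of that workflow, with placeholder arguments:

```python
from detectron2.config import get_cfg


def build_frozen_cfg(config_file: str, model_checkpoint: str, score_thresh: float = 0.8):
    """Build a frozen detectron2 config from a YAML file and a checkpoint path."""
    cfg = get_cfg()  # start from library defaults
    cfg.merge_from_file(config_file)  # overlay the model-specific YAML
    # Test-time confidence thresholds for the built-in model families,
    # as in the removed helper.
    cfg.MODEL.RETINANET.SCORE_THRESH_TEST = score_thresh
    cfg.MODEL.ROI_HEADS.SCORE_THRESH_TEST = score_thresh
    cfg.MODEL.PANOPTIC_FPN.COMBINE.INSTANCES_CONFIDENCE_THRESH = score_thresh
    cfg.MODEL.WEIGHTS = model_checkpoint
    cfg.freeze()  # make the config immutable before handing it to a predictor
    return cfg
```

Presumably `PosesGenerator` now performs the equivalent setup internally, which is what allows this helper to be deleted here.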
@@ -76,12 +63,8 @@ def __init__(self):
("pose_config",),
##################################
# Defaulted parameters
("inference_img_size", 1280), # inference size (pixels)
("det_conf_threshold", 0.25), # object confidence threshold
("iou_threshold", 0.25), # IOU threshold for NMS
("det_conf_threshold", 0.75), # object confidence threshold
("cuda_device_id", 0), # cuda device: ID int or CPU
("no_trace", True), # don`t trace model
("agnostic_nms", False), # class-agnostic NMS
# Runtime thread checkin heartbeat interval in seconds.
("rt_thread_heartbeat", 0.1),
# If we should enable additional logging to the info level
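The list passed to `declare_and_get_parameters` follows the tuple convention visible in the hunk above: a 1-tuple declares a required parameter, while a 2-tuple supplies a default value. A plain-Python illustration of just that convention (hypothetical spec values, not the angel_utils implementation):

```python
# Illustration only: 1-tuples are required parameters, 2-tuples carry defaults.
PARAM_SPEC = [
    ("pose_config",),              # required -- startup fails if not provided
    ("det_conf_threshold", 0.75),  # optional -- defaults to 0.75
    ("cuda_device_id", 0),         # optional -- defaults to GPU 0
]

for entry in PARAM_SPEC:
    name, *default = entry
    if default:
        print(f"{name}: defaults to {default[0]}")
    else:
        print(f"{name}: required")
```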
@@ -100,56 +83,30 @@ def __init__(self):

        self._ensure_image_resize = param_values["image_resize"]

-        self._inference_img_size = param_values["inference_img_size"]
        self._det_conf_thresh = param_values["det_conf_threshold"]
-        self._iou_thr = param_values["iou_threshold"]
        self._cuda_device_id = param_values["cuda_device_id"]
-        self._no_trace = param_values["no_trace"]
-        self._agnostic_nms = param_values["agnostic_nms"]

-        self.keypoints_cats = [
-            "nose",
-            "mouth",
-            "throat",
-            "chest",
-            "stomach",
-            "left_upper_arm",
-            "right_upper_arm",
-            "left_lower_arm",
-            "right_lower_arm",
-            "left_wrist",
-            "right_wrist",
-            "left_hand",
-            "right_hand",
-            "left_upper_leg",
-            "right_upper_leg",
-            "left_knee",
-            "right_knee",
-            "left_lower_leg",
-            "right_lower_leg",
-            "left_foot",
-            "right_foot",
-            "back",
-        ]
-
-        print("finished setting params")

print("loading detectron model")
# Detectron Model
print(f"model_checkpoint: {self.det_model_ckpt_fp}")
detecron_cfg = setup_detectron_cfg(
self.det_config, model_checkpoint=self.det_model_ckpt_fp
)

self.det_model = VisualizationDemo(detecron_cfg)

print("loading pose model")
# Pose model
self.pose_model = init_pose_model(
self.pose_config, self.pose_model_ckpt_fp, device=self._cuda_device_id
print("Initializing pose models...")
# Encapsulates detection and pose models.
self.pose_gen = PosesGenerator(
self.det_config,
self.pose_config,
self._det_conf_thresh,
self.det_model_ckpt_fp,
self._cuda_device_id,
self.pose_model_ckpt_fp,
self._cuda_device_id,
)

-        self.device = torch.device(f"cuda:{self._cuda_device_id}")
+        print("Initializing pose models... Done")
+        self.keypoints_cats = [
+            v
+            for _, v in sorted(self.pose_gen.pose_dataset_info.keypoint_id2name.items())
+        ]
+        # Pose estimates considered by this node is constrained to that of the
+        # patient class specifically.
+        self.patient_class_idx: int = DETECTION_CLASSES.index("patient")

        self._enable_trace_logging = param_values["enable_time_trace_logging"]

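The new code derives `keypoints_cats` from the pose model's own dataset metadata instead of the hard-coded 22-name list: sorting the id-to-name mapping by integer id yields names in keypoint-index order. A minimal sketch with a toy mapping (the real one comes from `pose_dataset_info`):

```python
# Toy stand-in for pose_dataset_info.keypoint_id2name; the real mapping is
# loaded from the MMPose dataset metadata by PosesGenerator.
keypoint_id2name = {2: "throat", 0: "nose", 1: "mouth"}

# sorted() on dict items orders by the integer id, so names come out in
# keypoint-index order regardless of dict insertion order.
keypoints_cats = [name for _, name in sorted(keypoint_id2name.items())]
assert keypoints_cats == ["nose", "mouth", "throat"]
```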
@@ -254,24 +211,33 @@ def rt_loop(self):
# print(f"img0: {img0.shape}")
# height, width, chans = img0.shape

+            boxes, scores, classes, keypoints = self.pose_gen.predict_single(img0)
+
+            # Select keypoints for most confidence box among those with
+            # keypoints.
+            patient_results = np.argwhere(classes == self.patient_class_idx)
+            if patient_results.size > 0:
+                p_scores = scores[patient_results]
+                best_idx = patient_results[np.argmax(p_scores)][0]
+                assert (
+                    keypoints[best_idx] is not None
+                ), "Patient class should have keypoints but None were found"
+                keypoints = [keypoints[best_idx]]
+            else:
+                keypoints = []

            all_poses_msg = HandJointPosesUpdate()
            # note: setting metdata right before publishing below

-            boxes, labels, keypoints = predict_single(
-                det_model=self.det_model,
-                pose_model=self.pose_model,
-                image=img0,
-                bbox_thr=self._det_conf_thresh,
-            )
-
-            # at most, we have 1 set of keypoints for 1 patient
-            keypoints = keypoints[:1]
+            # at most, we have 1 set of keypoints for 1 "best" patient
            for keypoints_ in keypoints:
                for label, keypoint in zip(self.keypoints_cats, keypoints_):
                    position = Point()
                    position.x = float(keypoint[0])
                    position.y = float(keypoint[1])
-                    position.z = float(keypoint[2])
+                    position.z = float(
+                        keypoint[2]
+                    )  # This is actually the confidence score

                    # Extract the orientation
                    orientation = Quaternion()
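The new `rt_loop` logic keeps keypoints only for the single highest-confidence detection of the patient class. A standalone sketch of that selection with toy inputs, treating the `predict_single` outputs as parallel per-box arrays/lists, as the hunk above implies:

```python
import numpy as np

# Toy stand-ins for one frame of predict_single() output; shapes and values
# are illustrative only.
classes = np.array([3, 7, 7, 1])            # class index per detection box
scores = np.array([0.9, 0.4, 0.8, 0.6])     # confidence per detection box
keypoints = [None, "kps_1", "kps_2", None]  # per-box keypoints (None if absent)
patient_class_idx = 7

patient_results = np.argwhere(classes == patient_class_idx)  # -> [[1], [2]]
if patient_results.size > 0:
    p_scores = scores[patient_results]                  # -> [[0.4], [0.8]]
    best_idx = patient_results[np.argmax(p_scores)][0]  # -> 2
    keypoints = [keypoints[best_idx]]                   # keep the best: ["kps_2"]
else:
    keypoints = []                                      # no patient detected
```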
