From b00c58697c16ccdba23659522ca07c7bf1580217 Mon Sep 17 00:00:00 2001 From: Ivan Varela Date: Thu, 30 May 2024 11:33:01 +0100 Subject: [PATCH 1/6] Added from_mmp_file function --- movement/io/load_poses.py | 100 ++++++++++++++++++++++++++++++++++++++ 1 file changed, 100 insertions(+) diff --git a/movement/io/load_poses.py b/movement/io/load_poses.py index e2b309109..af956ba3b 100644 --- a/movement/io/load_poses.py +++ b/movement/io/load_poses.py @@ -1,5 +1,6 @@ """Functions for loading pose tracking data from various frameworks.""" +import json import logging from pathlib import Path from typing import Literal, Optional, Union @@ -678,3 +679,102 @@ def _ds_from_valid_data(data: ValidPosesDataset) -> xr.Dataset: "source_file": None, }, ) + + +def from_mmp_file( + file_path: Union[Path, str], fps: Optional[float] = None +) -> xr.Dataset: + """Load pose tracking data from a JSON file with multiple individuals. + + This function expects JSON data in the following format: + [ + { + "frame_id": int, + "instances": [ + { + "keypoints": list[list[float]], + "keypoint_scores": list[float], + "bbox": list[float], + "bbox_score": float + }, + ... + ] + }, + ... + ] + + Parameters + ---------- + file_path : pathlib.Path or str + Path to the JSON file containing the pose tracking data. + fps : float, optional + The number of frames per second in the video. If None (default), + the 'time' coordinates will be in frame numbers. + + Returns + ------- + xr.Dataset + Dataset containing the pose tracks, confidence scores, and metadata. + + """ + file = ValidFile( + file_path, + expected_permission="r", + expected_suffix=[".json"], + ) + + # file = ValidMMPoseJSON(file_path) + + with open(file.path) as f: + data = json.load(f) + + all_tracks = [] + all_scores = [] + n_keypoints = len( + data[0]["instances"][0]["keypoints"] + ) # Keypoints in the first instance of the first frame + + for _, frame in enumerate(data): + # Initialize arrays for this frame's tracks and scores + n_individuals = len(frame["instances"]) + frame_tracks = np.full( + (n_individuals, n_keypoints, 2), np.nan, dtype=np.float32 + ) + frame_scores = np.full( + (n_individuals, n_keypoints), np.nan, dtype=np.float32 + ) + + for i, instance in enumerate(frame["instances"]): + frame_tracks[i] = np.array(instance["keypoints"])[:, :2] # (x, y) + frame_scores[i] = instance["keypoint_scores"] + + all_tracks.append(frame_tracks) + all_scores.append(frame_scores) + + # Stack the frames to get a 3D array + tracks_array = np.stack(all_tracks, axis=0) + scores_array = np.stack(all_scores, axis=0) + + keypoint_names = [ + f"keypoint_{i}" for i in range(n_keypoints) + ] # Use pre-calculated keypoint names + + # Create ValidPosesDataset and convert to xarray.Dataset + valid_data = ValidPosesDataset( + position_array=tracks_array, + confidence_array=scores_array, + individual_names=[ + f"individual_{i}" for i in range(tracks_array.shape[1]) + ], + keypoint_names=keypoint_names, + fps=fps, + ) + ds = _from_valid_data(valid_data) + + # Metadata + ds.attrs["source_software"] = "JSON" + ds.attrs["source_file"] = file.path.as_posix() + + logger.info(f"Loaded pose tracks from {file.path}:") + logger.info(ds) + return ds From f4983d7b4f82cd240005d28202c6fe771f6b0373 Mon Sep 17 00:00:00 2001 From: Ivan Varela Date: Tue, 4 Jun 2024 13:25:28 +0100 Subject: [PATCH 2/6] changed _ds_from_valid_data call name --- movement/io/load_poses.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/movement/io/load_poses.py b/movement/io/load_poses.py index af956ba3b..29f2ab80e 100644 --- a/movement/io/load_poses.py +++ b/movement/io/load_poses.py @@ -769,7 +769,7 @@ def from_mmp_file( keypoint_names=keypoint_names, fps=fps, ) - ds = _from_valid_data(valid_data) + ds = _ds_from_valid_data(valid_data) # Metadata ds.attrs["source_software"] = "JSON" From c505047a355c1f0aa284e8448ae7f15d5b6c1a73 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Iv=C3=A1n=20V?= <131998802+ivanvrlg@users.noreply.github.com> Date: Tue, 23 Jul 2024 13:34:49 +0100 Subject: [PATCH 3/6] Update movement/io/load_poses.py Co-authored-by: Niko Sirmpilatze --- movement/io/load_poses.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/movement/io/load_poses.py b/movement/io/load_poses.py index 29f2ab80e..25d93c876 100644 --- a/movement/io/load_poses.py +++ b/movement/io/load_poses.py @@ -713,9 +713,9 @@ def from_mmp_file( Returns ------- - xr.Dataset - Dataset containing the pose tracks, confidence scores, and metadata. - + xarray.Dataset + ``movement`` dataset containing the pose tracks, confidence scores, + and associated metadata. """ file = ValidFile( file_path, From 3bc1c0c69a4856a018139e5ebcdfe0f4a9bd7a44 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Iv=C3=A1n=20V?= <131998802+ivanvrlg@users.noreply.github.com> Date: Tue, 23 Jul 2024 13:34:57 +0100 Subject: [PATCH 4/6] Update movement/io/load_poses.py Co-authored-by: Niko Sirmpilatze --- movement/io/load_poses.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/movement/io/load_poses.py b/movement/io/load_poses.py index 25d93c876..1ad45d5c7 100644 --- a/movement/io/load_poses.py +++ b/movement/io/load_poses.py @@ -686,7 +686,7 @@ def from_mmp_file( ) -> xr.Dataset: """Load pose tracking data from a JSON file with multiple individuals. - This function expects JSON data in the following format: + Path to the .json file containing the pose tracking data. [ { "frame_id": int, From 4c4279284e852b84c5f170cdbe7bcd1d883d725c Mon Sep 17 00:00:00 2001 From: "pre-commit-ci[bot]" <66853113+pre-commit-ci[bot]@users.noreply.github.com> Date: Tue, 23 Jul 2024 12:35:04 +0000 Subject: [PATCH 5/6] [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci --- movement/io/load_poses.py | 1 + 1 file changed, 1 insertion(+) diff --git a/movement/io/load_poses.py b/movement/io/load_poses.py index 1ad45d5c7..4496d5bdc 100644 --- a/movement/io/load_poses.py +++ b/movement/io/load_poses.py @@ -716,6 +716,7 @@ def from_mmp_file( xarray.Dataset ``movement`` dataset containing the pose tracks, confidence scores, and associated metadata. + """ file = ValidFile( file_path, From a02f2486d4feafede30c9b3e0002819b08a2597f Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Iv=C3=A1n=20V?= <131998802+ivanvrlg@users.noreply.github.com> Date: Tue, 23 Jul 2024 13:35:14 +0100 Subject: [PATCH 6/6] Update movement/io/load_poses.py Co-authored-by: Niko Sirmpilatze --- movement/io/load_poses.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/movement/io/load_poses.py b/movement/io/load_poses.py index 4496d5bdc..c551d925c 100644 --- a/movement/io/load_poses.py +++ b/movement/io/load_poses.py @@ -684,7 +684,7 @@ def _ds_from_valid_data(data: ValidPosesDataset) -> xr.Dataset: def from_mmp_file( file_path: Union[Path, str], fps: Optional[float] = None ) -> xr.Dataset: - """Load pose tracking data from a JSON file with multiple individuals. + """Create a ``movement`` dataset from an MMPose .json file Path to the .json file containing the pose tracking data. [