diff --git a/examples/nwb_conversion.py b/examples/nwb_conversion.py new file mode 100644 index 000000000..f0001b32f --- /dev/null +++ b/examples/nwb_conversion.py @@ -0,0 +1,55 @@ +"""Converting movement dataset to NWB or loading from NWB to movement dataset. +============================ + +Export pose tracks to NWB +""" + +# %% Load the sample data +import datetime + +from pynwb import NWBHDF5IO, NWBFile + +from movement import sample_data +from movement.io.nwb import ( + add_movement_dataset_to_nwb, + convert_nwb_to_movement, +) + +ds = sample_data.fetch_dataset("DLC_two-mice.predictions.csv") + +# %%The dataset has two individuals. +# We will create two NWBFiles for each individual + +session_start_time = datetime.datetime.now(datetime.timezone.utc) +nwbfile_individual1 = NWBFile( + session_description="session_description", + identifier="individual1", + session_start_time=session_start_time, +) +nwbfile_individual2 = NWBFile( + session_description="session_description", + identifier="individual2", + session_start_time=session_start_time, +) + +nwbfiles = [nwbfile_individual1, nwbfile_individual2] + +# %% Convert the dataset to NWB +# This will create PoseEstimation and Skeleton objects for each +# individual and add them to the NWBFile +add_movement_dataset_to_nwb(nwbfiles, ds) + +# %% Save the NWBFiles +for file in nwbfiles: + with NWBHDF5IO(f"{file.identifier}.nwb", "w") as io: + io.write(file) + +# %% Convert the NWBFiles back to a movement dataset +# This will create a movement dataset with the same data as +# the original dataset from the NWBFiles + +# Convert the NWBFiles to a movement dataset +ds_from_nwb = convert_nwb_to_movement( + nwb_filepaths=["individual1.nwb", "individual2.nwb"] +) +ds_from_nwb diff --git a/movement/io/nwb.py b/movement/io/nwb.py new file mode 100644 index 000000000..5a7284904 --- /dev/null +++ b/movement/io/nwb.py @@ -0,0 +1,293 @@ +"""Functions to convert movement data to and from NWB format.""" + +from pathlib import Path + +import ndx_pose +import numpy as np +import pynwb +import xarray as xr + +from movement.logging import log_error + + +def _create_pose_and_skeleton_objects( + ds: xr.Dataset, + subject: str, + pose_estimation_series_kwargs: dict | None = None, + pose_estimation_kwargs: dict | None = None, + skeleton_kwargs: dict | None = None, +) -> tuple[list[ndx_pose.PoseEstimation], ndx_pose.Skeletons]: + """Create PoseEstimation and Skeletons objects from a ``movement`` dataset. + + Parameters + ---------- + ds : xarray.Dataset + movement dataset containing the data to be converted to NWB. + subject : str + Name of the subject (individual) to be converted. + pose_estimation_series_kwargs : dict, optional + PoseEstimationSeries keyword arguments. See ndx_pose, by default None + pose_estimation_kwargs : dict, optional + PoseEstimation keyword arguments. See ndx_pose, by default None + skeleton_kwargs : dict, optional + Skeleton keyword arguments. See ndx_pose, by default None + + Returns + ------- + pose_estimation : list[ndx_pose.PoseEstimation] + List of PoseEstimation objects + skeletons : ndx_pose.Skeletons + Skeletons object containing all skeletons + + """ + if pose_estimation_series_kwargs is None: + pose_estimation_series_kwargs = dict( + reference_frame="(0,0,0) corresponds to ...", + confidence_definition=None, + conversion=1.0, + resolution=-1.0, + offset=0.0, + starting_time=None, + comments="no comments", + description="no description", + control=None, + control_description=None, + ) + + if skeleton_kwargs is None: + skeleton_kwargs = dict(edges=None) + + if pose_estimation_kwargs is None: + pose_estimation_kwargs = dict( + original_videos=None, + labeled_videos=None, + dimensions=None, + devices=None, + scorer=None, + source_software_version=None, + ) + + pose_estimation_series = [] + + for keypoint in ds.keypoints.to_numpy(): + pose_estimation_series.append( + ndx_pose.PoseEstimationSeries( + name=keypoint, + data=ds.sel(keypoints=keypoint).position.to_numpy(), + confidence=ds.sel(keypoints=keypoint).confidence.to_numpy(), + unit="pixels", + timestamps=ds.sel(keypoints=keypoint).time.to_numpy(), + **pose_estimation_series_kwargs, + ) + ) + + skeleton_list = [ + ndx_pose.Skeleton( + name=f"{subject}_skeleton", + nodes=ds.keypoints.to_numpy().tolist(), + **skeleton_kwargs, + ) + ] + + bodyparts_str = ", ".join(ds.keypoints.to_numpy().tolist()) + description = ( + f"Estimated positions of {bodyparts_str} of" + f"{subject} using {ds.source_software}." + ) + + pose_estimation = [ + ndx_pose.PoseEstimation( + name="PoseEstimation", + pose_estimation_series=pose_estimation_series, + description=description, + source_software=ds.source_software, + skeleton=skeleton_list[-1], + **pose_estimation_kwargs, + ) + ] + + skeletons = ndx_pose.Skeletons(skeletons=skeleton_list) + + return pose_estimation, skeletons + + +def add_movement_dataset_to_nwb( + nwbfiles: list[pynwb.NWBFile] | pynwb.NWBFile, + movement_dataset: xr.Dataset, + pose_estimation_series_kwargs: dict | None = None, + pose_estimation_kwargs: dict | None = None, + skeletons_kwargs: dict | None = None, +) -> None: + """Add pose estimation data to NWB files for each individual. + + Parameters + ---------- + nwbfiles : list[pynwb.NWBFile] | pynwb.NWBFile + NWBFile object(s) to which the data will be added. + movement_dataset : xr.Dataset + ``movement`` dataset containing the data to be converted to NWB. + pose_estimation_series_kwargs : dict, optional + PoseEstimationSeries keyword arguments. See ndx_pose, by default None + pose_estimation_kwargs : dict, optional + PoseEstimation keyword arguments. See ndx_pose, by default None + skeletons_kwargs : dict, optional + Skeleton keyword arguments. See ndx_pose, by default None + + Raises + ------ + ValueError + If the number of NWBFiles is not equal to the number of individuals + in the dataset. + + """ + if isinstance(nwbfiles, pynwb.NWBFile): + nwbfiles = [nwbfiles] + + if len(nwbfiles) != len(movement_dataset.individuals): + raise log_error( + ValueError, + "Number of NWBFiles must be equal to the number of individuals. " + "NWB requires one file per individual.", + ) + + for nwbfile, subject in zip( + nwbfiles, movement_dataset.individuals.to_numpy(), strict=False + ): + pose_estimation, skeletons = _create_pose_and_skeleton_objects( + movement_dataset.sel(individuals=subject), + subject, + pose_estimation_series_kwargs, + pose_estimation_kwargs, + skeletons_kwargs, + ) + try: + behavior_pm = nwbfile.create_processing_module( + name="behavior", + description="processed behavioral data", + ) + except ValueError: + print("Behavior processing module already exists. Skipping...") + behavior_pm = nwbfile.processing["behavior"] + + try: + behavior_pm.add(skeletons) + except ValueError: + print("Skeletons already exists. Skipping...") + try: + behavior_pm.add(pose_estimation) + except ValueError: + print("PoseEstimation already exists. Skipping...") + + +def _convert_pose_estimation_series( + pose_estimation_series: ndx_pose.PoseEstimationSeries, + keypoint: str, + subject_name: str, + source_software: str, + source_file: str | None = None, +) -> xr.Dataset: + """Convert to single-keypoint, single-individual ``movement`` dataset. + + Parameters + ---------- + pose_estimation_series : ndx_pose.PoseEstimationSeries + PoseEstimationSeries NWB object to be converted. + keypoint : str + Name of the keypoint - body part. + subject_name : str + Name of the subject (individual). + source_software : str + Name of the software used to estimate the pose. + source_file : Optional[str], optional + File from which the data was extracted, by default None + + Returns + ------- + movement_dataset : xr.Dataset + ``movement`` compatible dataset containing the pose estimation data. + + """ + attrs = { + "fps": np.nanmedian(1 / np.diff(pose_estimation_series.timestamps)), + "time_units": pose_estimation_series.timestamps_unit, + "source_software": source_software, + "source_file": source_file, + } + n_space_dims = pose_estimation_series.data.shape[1] + space_dims = ["x", "y", "z"] + + position_array = np.asarray(pose_estimation_series.data)[ + :, np.newaxis, np.newaxis, : + ] + + if getattr(pose_estimation_series, "confidence", None) is None: + pose_estimation_series.confidence = np.full( + pose_estimation_series.data.shape[0], np.nan + ) + else: + confidence_array = np.asarray(pose_estimation_series.confidence)[ + :, np.newaxis, np.newaxis + ] + + return xr.Dataset( + data_vars={ + "position": ( + ["time", "individuals", "keypoints", "space"], + position_array, + ), + "confidence": ( + ["time", "individuals", "keypoints"], + confidence_array, + ), + }, + coords={ + "time": pose_estimation_series.timestamps, + "individuals": [subject_name], + "keypoints": [keypoint], + "space": space_dims[:n_space_dims], + }, + attrs=attrs, + ) + + +def convert_nwb_to_movement( + nwb_filepaths: str | list[str] | list[Path], +) -> xr.Dataset: + """Convert a list of NWB files to a single ``movement`` dataset. + + Parameters + ---------- + nwb_filepaths : str | Path | list[str] | list[Path] + List of paths to NWB files to be converted. + + Returns + ------- + movement_ds : xr.Dataset + ``movement`` dataset containing the pose estimation data. + + """ + if isinstance(nwb_filepaths, str | Path): + nwb_filepaths = [nwb_filepaths] + + datasets = [] + for path in nwb_filepaths: + with pynwb.NWBHDF5IO(path, mode="r") as io: + nwbfile = io.read() + pose_estimation = nwbfile.processing["behavior"]["PoseEstimation"] + source_software = pose_estimation.fields["source_software"] + pose_estimation_series = pose_estimation.fields[ + "pose_estimation_series" + ] + + for keypoint, pes in pose_estimation_series.items(): + datasets.append( + _convert_pose_estimation_series( + pes, + keypoint, + subject_name=nwbfile.identifier, + source_software=source_software, + source_file=None, + ) + ) + + return xr.merge(datasets) diff --git a/pyproject.toml b/pyproject.toml index 27348c291..1d227e01d 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -22,6 +22,8 @@ dependencies = [ "sleap-io", "xarray[accel,viz]", "PyYAML", + "pynwb", + "ndx-pose>=0.2", ] classifiers = [ diff --git a/tests/test_unit/test_nwb.py b/tests/test_unit/test_nwb.py new file mode 100644 index 000000000..060c2e992 --- /dev/null +++ b/tests/test_unit/test_nwb.py @@ -0,0 +1,253 @@ +import datetime + +import ndx_pose +import numpy as np +from ndx_pose import PoseEstimation, PoseEstimationSeries, Skeleton, Skeletons +from pynwb import NWBHDF5IO, NWBFile +from pynwb.file import Subject + +from movement import sample_data +from movement.io.nwb import ( + _convert_pose_estimation_series, + _create_pose_and_skeleton_objects, + add_movement_dataset_to_nwb, + convert_nwb_to_movement, +) + + +def test_create_pose_and_skeleton_objects(): + # Create a sample dataset + ds = sample_data.fetch_dataset("DLC_two-mice.predictions.csv") + + # Call the function + pose_estimation, skeletons = _create_pose_and_skeleton_objects( + ds.sel(individuals="individual1"), + subject="individual1", + pose_estimation_series_kwargs=None, + pose_estimation_kwargs=None, + skeleton_kwargs=None, + ) + + # Assert the output types + assert isinstance(pose_estimation, list) + assert isinstance(skeletons, ndx_pose.Skeletons) + + # Assert the length of pose_estimation list + assert len(pose_estimation) == 1 + + # Assert the length of pose_estimation_series list + assert len(pose_estimation[0].pose_estimation_series) == 12 + + # Assert the name of the first PoseEstimationSeries + assert "snout" in pose_estimation[0].pose_estimation_series + + # Assert the name of the Skeleton + assert "individual1_skeleton" in skeletons.skeletons + + +def create_test_pose_estimation_series( + n_time=100, n_dims=2, keypoint="front_left_paw" +): + data = np.random.rand( + n_time, n_dims + ) # num_frames x (x, y) but can be (x, y, z) + timestamps = np.linspace(0, 10, num=n_time) # a timestamp for every frame + confidence = np.ones((n_time,)) # a confidence value for every frame + reference_frame = "(0,0,0) corresponds to ..." + confidence_definition = "Softmax output of the deep neural network." + + return PoseEstimationSeries( + name=keypoint, + description="Marker placed around fingers of front left paw.", + data=data, + unit="pixels", + reference_frame=reference_frame, + timestamps=timestamps, + confidence=confidence, + confidence_definition=confidence_definition, + ) + + +def test__convert_pose_estimation_series(): + # Create a sample PoseEstimationSeries object + pose_estimation_series = create_test_pose_estimation_series( + n_time=100, n_dims=2, keypoint="front_left_paw" + ) + + # Call the function + movement_dataset = _convert_pose_estimation_series( + pose_estimation_series, + keypoint="leftear", + subject_name="individual1", + source_software="software1", + source_file="file1", + ) + + # Assert the dimensions of the movement dataset + assert movement_dataset.sizes == { + "time": 100, + "individuals": 1, + "keypoints": 1, + "space": 2, + } + + # Assert the values of the position variable + np.testing.assert_array_equal( + movement_dataset["position"].values, + pose_estimation_series.data[:, np.newaxis, np.newaxis, :], + ) + + # Assert the values of the confidence variable + np.testing.assert_array_equal( + movement_dataset["confidence"].values, + pose_estimation_series.confidence[:, np.newaxis, np.newaxis], + ) + + # Assert the attributes of the movement dataset + assert movement_dataset.attrs == { + "fps": np.nanmedian(1 / np.diff(pose_estimation_series.timestamps)), + "time_units": pose_estimation_series.timestamps_unit, + "source_software": "software1", + "source_file": "file1", + } + pose_estimation_series = create_test_pose_estimation_series( + n_time=50, n_dims=3, keypoint="front_left_paw" + ) + + # Assert the dimensions of the movement dataset + assert movement_dataset.sizes == { + "time": 50, + "individuals": 1, + "keypoints": 1, + "space": 3, + } + + +def test_add_movement_dataset_to_nwb_single_file(): + ds = sample_data.fetch_dataset("DLC_two-mice.predictions.csv") + session_start_time = datetime.datetime.now(datetime.timezone.utc) + nwbfile_individual1 = NWBFile( + session_description="session_description", + identifier="individual1", + session_start_time=session_start_time, + ) + add_movement_dataset_to_nwb( + nwbfile_individual1, ds.sel(individuals=["individual1"]) + ) + assert ( + "PoseEstimation" + in nwbfile_individual1.processing["behavior"].data_interfaces + ) + assert ( + "Skeletons" + in nwbfile_individual1.processing["behavior"].data_interfaces + ) + + +def test_add_movement_dataset_to_nwb_multiple_files(): + ds = sample_data.fetch_dataset("DLC_two-mice.predictions.csv") + session_start_time = datetime.datetime.now(datetime.timezone.utc) + nwbfile_individual1 = NWBFile( + session_description="session_description", + identifier="individual1", + session_start_time=session_start_time, + ) + nwbfile_individual2 = NWBFile( + session_description="session_description", + identifier="individual2", + session_start_time=session_start_time, + ) + + nwbfiles = [nwbfile_individual1, nwbfile_individual2] + add_movement_dataset_to_nwb(nwbfiles, ds) + + +def create_test_pose_nwb(identifier="subject1", write_to_disk=False): + # initialize an NWBFile object + nwbfile = NWBFile( + session_description="session_description", + identifier=identifier, + session_start_time=datetime.datetime.now(datetime.timezone.utc), + ) + + # add a subject to the NWB file + subject = Subject(subject_id=identifier, species="Mus musculus") + nwbfile.subject = subject + + skeleton = Skeleton( + name="subject1_skeleton", + nodes=["front_left_paw", "body", "front_right_paw"], + edges=np.array([[0, 1], [1, 2]], dtype="uint8"), + subject=subject, + ) + + skeletons = Skeletons(skeletons=[skeleton]) + + # create a device for the camera + camera1 = nwbfile.create_device( + name="camera1", + description="camera for recording behavior", + manufacturer="my manufacturer", + ) + + n_time = 100 + n_dims = 2 # 2D data + front_left_paw = create_test_pose_estimation_series( + n_time=n_time, n_dims=n_dims, keypoint="front_left_paw" + ) + + body = create_test_pose_estimation_series( + n_time=n_time, n_dims=n_dims, keypoint="body" + ) + front_right_paw = create_test_pose_estimation_series( + n_time=n_time, n_dims=n_dims, keypoint="front_right_paw" + ) + + # store all PoseEstimationSeries in a list + pose_estimation_series = [front_left_paw, body, front_right_paw] + + pose_estimation = PoseEstimation( + name="PoseEstimation", + pose_estimation_series=pose_estimation_series, + description=( + "Estimated positions of front paws" "of subject1 using DeepLabCut." + ), + original_videos=["path/to/camera1.mp4"], + labeled_videos=["path/to/camera1_labeled.mp4"], + dimensions=np.array( + [[640, 480]], dtype="uint16" + ), # pixel dimensions of the video + devices=[camera1], + scorer="DLC_resnet50_openfieldOct30shuffle1_1600", + source_software="DeepLabCut", + source_software_version="2.3.8", + skeleton=skeleton, # link to the skeleton object + ) + + behavior_pm = nwbfile.create_processing_module( + name="behavior", + description="processed behavioral data", + ) + behavior_pm.add(skeletons) + behavior_pm.add(pose_estimation) + + # write the NWBFile to disk + if write_to_disk: + path = "test_pose.nwb" + with NWBHDF5IO(path, mode="w") as io: + io.write(nwbfile) + else: + return nwbfile + + +def test_convert_nwb_to_movement(): + create_test_pose_nwb(write_to_disk=True) + nwb_filepaths = ["test_pose.nwb"] + movement_dataset = convert_nwb_to_movement(nwb_filepaths) + + assert movement_dataset.sizes == { + "time": 100, + "individuals": 1, + "keypoints": 3, + "space": 2, + }