Merge pull request #5 from PPierzc/main
ArXiv release
PPierzc authored Oct 20, 2022
2 parents (6634db8 + d19a7db), commit 08105b3
Showing 77 changed files with 463 additions and 341 deletions.
.isort.cfg: 2 changes (2 additions & 0 deletions)
@@ -0,0 +1,2 @@
[settings]
profile=black
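
The new `.isort.cfg` enables isort's black-compatible profile, which explains the import reshuffling in the `propose/*.py` diffs further down. As a rough illustration (an editorial sketch, not part of this commit; assumes isort >= 5), the same profile can be applied through isort's Python API:

```python
# Sketch: apply the black-compatible profile via isort's Python API.
# The sample import block below is illustrative, not taken from the repository.
import isort

sample = "import torch\nimport numpy as np\nimport pickle\n"
print(isort.code(sample, profile="black"))
# Prints the imports grouped into sections (stdlib first, then third-party)
# and sorted alphabetically within each section.
```
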
README.md: 8 changes (4 additions & 4 deletions)
@@ -54,10 +54,10 @@ docker pull sinzlab/pytorch:v3.9-torch1.9.0-cuda11.1-dj0.12.7
5. You can now open JupyterLab in your browser at [`http://localhost:10101`](http://localhost:10101).

#### Available Models
| Model Name | description | Artifact path | Import Code |
| --- |---------------------------------------------------------------------------------------------------------------------------------------------------------|---------------------------------|----------------------------------|
| cGNF Human 3.6m | Model trained on the Human 3.6M dataset with MPII input keypoints. | ```ppierzc/cgnf/cgnf_human36m:best``` | ```from propose.models.flows import CondGraphFlow``` |
| HRNet | Instance of the [official](https://github.com/leoxiaobin/deep-high-resolution-net.pytorch) HRNet model trained on the MPII dataset with w32 and 256x256 | ```ppierzc/cgnf/hrnet:v0``` | ```from propose.models.detectors import HRNet``` |
| Model Name | description | Artifact path | Import Code |
| --- |---------------------------------------------------------------------------------------------------------------------------------------------------------|-----------------------------------------------|----------------------------------|
| cGNF Human 3.6m | Model trained on the Human 3.6M dataset with MPII input keypoints. | ```ppierzc/propose_human36m/mpii-prod:best``` | ```from propose.models.flows import CondGraphFlow``` |
| HRNet | Instance of the [official](https://github.com/leoxiaobin/deep-high-resolution-net.pytorch) HRNet model trained on the MPII dataset with w32 and 256x256 | ```ppierzc/cgnf/hrnet:v0``` | ```from propose.models.detectors import HRNet``` |
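
The artifact paths in the table look like Weights & Biases references (`entity/project/name:alias`). Below is a minimal sketch of fetching the cGNF checkpoint with the public wandb API; the artifact path and the import are taken from the table, while the surrounding download code is an assumption about how the artifact would be retrieved, and the final loading step depends on the `propose` package itself:

```python
# Sketch under assumptions: download the pretrained cGNF artifact via wandb.
# Only the artifact path and the CondGraphFlow import come from the table above;
# turning the downloaded files into a model instance is left to the propose API.
import wandb

from propose.models.flows import CondGraphFlow  # import shown in the table

api = wandb.Api()
artifact = api.artifact("ppierzc/propose_human36m/mpii-prod:best")
checkpoint_dir = artifact.download()  # local directory with the checkpoint files
# Constructing a CondGraphFlow from these files is repository-specific and not shown.
```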

### Run Tests
To run the tests, from the root directory call:
notebooks/demo/demo.ipynb: 32 changes (18 additions & 14 deletions)
@@ -30,8 +30,8 @@
"execution_count": 3,
"outputs": [],
"source": [
"mocap_path = f'../../data/rat7m/mocap/mocap-s4-d1.mat'\n",
"vid_path = f'../../data/rat7m/movies/s4-d1/s4-d1-camera4-0.mp4'\n",
"mocap_path = f\"../../data/rat7m/mocap/mocap-s4-d1.mat\"\n",
"vid_path = f\"../../data/rat7m/movies/s4-d1/s4-d1-camera4-0.mp4\"\n",
"\n",
"vid = imageio.get_reader(vid_path)"
],
@@ -77,7 +77,7 @@
"\n",
"pose = pose[mask]\n",
"\n",
"np.array(pose._edge('HeadF', 'HeadB'))"
"np.array(pose._edge(\"HeadF\", \"HeadB\"))"
],
"metadata": {
"collapsed": false,
@@ -103,9 +103,9 @@
],
"source": [
"fig = plt.figure(figsize=(10, 10))\n",
"ax1 = fig.add_subplot(1, 1, 1, projection='3d')\n",
"ax1 = fig.add_subplot(1, 1, 1, projection=\"3d\")\n",
"ax1.get_proj = lambda: np.dot(Axes3D.get_proj(ax1), np.diag([1, 1, 0.75, 1]))\n",
"ax1.view_init(30, 30)\n",
"ax1.view_init(30, 30)\n",
"ax1.set_xlim(-400, -100)\n",
"ax1.set_ylim(-300, 0)\n",
"ax1.set_zlim(0, 100)\n",
@@ -136,7 +136,7 @@
}
],
"source": [
"camera = cameras['Camera4']\n",
"camera = cameras[\"Camera4\"]\n",
"pose2D = Rat7mPose(camera.proj2D(pose))\n",
"\n",
"frame_idx = camera.frames.squeeze()[mask][0]\n",
@@ -169,7 +169,7 @@
],
"source": [
"cameras = load_cameras(mocap_path)\n",
"camera = cameras['Camera4']\n",
"camera = cameras[\"Camera4\"]\n",
"camera.frames = camera.frames.squeeze()[mask]\n",
"pose_idx = 0\n",
"\n",
@@ -183,23 +183,24 @@
"\n",
"fig = plt.figure(figsize=(20, 10))\n",
"\n",
"ax1 = fig.add_subplot(1, 2, 1, projection='3d')\n",
"ax1 = fig.add_subplot(1, 2, 1, projection=\"3d\")\n",
"ax1.get_proj = lambda: np.dot(Axes3D.get_proj(ax1), np.diag([1, 1, 0.75, 1]))\n",
"ax1.view_init(30, 30)\n",
"ax1.view_init(30, 30)\n",
"ax1.set_xlim(-400, -100)\n",
"ax1.set_ylim(-300, 0)\n",
"ax1.set_zlim(0, 100)\n",
"\n",
"ax2 = fig.add_subplot(1, 2, 2)\n",
"ax2.set_title('Camera 4')\n",
"plt.axis('off')\n",
"ax2.set_title(\"Camera 4\")\n",
"plt.axis(\"off\")\n",
"\n",
"img = ax2.imshow(im)\n",
"animate1 = pose.animate(ax1)\n",
"animate2 = pose2D.animate(ax2)\n",
"\n",
"plt.close(fig)\n",
"\n",
"\n",
"def animate(i):\n",
" frame_idx = camera.frames.squeeze()[pose_idx + i]\n",
" im = vid.get_data(frame_idx)\n",
@@ -209,14 +209,17 @@
" animate1(i)\n",
" animate2(i)\n",
"\n",
"\n",
"ani = animation.FuncAnimation(fig, animate, frames=100)\n",
"\n",
"Writer = animation.writers['ffmpeg']\n",
"writer = Writer(fps=30, metadata=dict(artist='Me'), bitrate=-1)\n",
"Writer = animation.writers[\"ffmpeg\"]\n",
"writer = Writer(fps=30, metadata=dict(artist=\"Me\"), bitrate=-1)\n",
"\n",
"pbar = tqdm(total=100, position=0)\n",
"\n",
"ani.save('walk_cam_4_sub.mp4', writer=writer, progress_callback=lambda i, n: pbar.update(1))"
"ani.save(\n",
" \"walk_cam_4_sub.mp4\", writer=writer, progress_callback=lambda i, n: pbar.update(1)\n",
")"
],
"metadata": {
"collapsed": false,
notebooks/demo/load_human36m.ipynb: 22 changes (11 additions & 11 deletions)
@@ -42,28 +42,28 @@
"path = \"/Users/paulpierzchlewicz/PycharmProjects/propose/data/human36m/Directions.60457274.cdf\"\n",
"poses = load_poses(path)\n",
"\n",
"poses /= poses.std() # Normalize poses\n",
"poses /= poses.std() # Normalize poses\n",
"\n",
"poses = Human36mPose(poses)\n",
"\n",
"pose = poses[200]\n",
"\n",
"plt.style.use('default')\n",
"plt.style.use(\"default\")\n",
"fig = plt.figure()\n",
"ax = fig.add_subplot(111, projection='3d')\n",
"ax.view_init(elev=15., azim=120)\n",
"poses[395].plot(ax=ax, alpha=.1)\n",
"poses[396].plot(ax=ax, alpha=.2)\n",
"poses[397].plot(ax=ax, alpha=.3)\n",
"poses[398].plot(ax=ax, alpha=.4)\n",
"poses[399].plot(ax=ax, alpha=.5)\n",
"ax = fig.add_subplot(111, projection=\"3d\")\n",
"ax.view_init(elev=15.0, azim=120)\n",
"poses[395].plot(ax=ax, alpha=0.1)\n",
"poses[396].plot(ax=ax, alpha=0.2)\n",
"poses[397].plot(ax=ax, alpha=0.3)\n",
"poses[398].plot(ax=ax, alpha=0.4)\n",
"poses[399].plot(ax=ax, alpha=0.5)\n",
"poses[400].plot(ax=ax, alpha=1)\n",
"\n",
"ax.set_xlim(2, -2)\n",
"ax.set_ylim(2, -2)\n",
"ax.set_zlim(-2, 2)\n",
"\n",
"ax.xaxis.pane.fill = False # Left pane\n",
"ax.xaxis.pane.fill = False # Left pane\n",
"ax.yaxis.pane.fill = False\n",
"ax.zaxis.pane.fill = False\n",
"ax.grid(False)\n",
@@ -88,7 +88,7 @@
"ax.set_yticks([])\n",
"ax.set_zticks([])\n",
"\n",
"plt.savefig('./human36m_pose.png', dpi=300)\n",
"plt.savefig(\"./human36m_pose.png\", dpi=300)\n",
"\n",
"plt.show()"
],
notebooks/demo/load_rat7m_dataset_demo.ipynb: 44 changes (24 additions & 20 deletions)
@@ -46,8 +46,8 @@
"execution_count": 7,
"outputs": [],
"source": [
"dirname = '/Users/paulpierzchlewicz/PycharmProjects/propose/data/rat7m' # Choose this such that it points to your dataset\n",
"data_key = 's4-d1'"
"dirname = \"/Users/paulpierzchlewicz/PycharmProjects/propose/data/rat7m\" # Choose this such that it points to your dataset\n",
"data_key = \"s4-d1\""
],
"metadata": {
"collapsed": false,
@@ -61,17 +61,21 @@
"execution_count": 18,
"outputs": [],
"source": [
"dataset = Rat7mDataset(dirname=dirname, data_key=data_key, transforms=[\n",
" tr.SwitchArmsElbows(),\n",
" tr.CropImageToPose(),\n",
" tr.RotatePoseToCamera(),\n",
" tr.CenterPose(),\n",
" tr.ScalePose(scale=0.03),\n",
" ScaleInputs(scale=0.1, multichannel=True, anti_aliasing=True),\n",
" tr.NormaliseImageScale(),\n",
" tr.ToGraph(),\n",
" ToTensor()\n",
"])"
"dataset = Rat7mDataset(\n",
" dirname=dirname,\n",
" data_key=data_key,\n",
" transforms=[\n",
" tr.SwitchArmsElbows(),\n",
" tr.CropImageToPose(),\n",
" tr.RotatePoseToCamera(),\n",
" tr.CenterPose(),\n",
" tr.ScalePose(scale=0.03),\n",
" ScaleInputs(scale=0.1, multichannel=True, anti_aliasing=True),\n",
" tr.NormaliseImageScale(),\n",
" tr.ToGraph(),\n",
" ToTensor(),\n",
" ],\n",
")"
],
"metadata": {
"collapsed": false,
@@ -103,25 +107,25 @@
"pose = Rat7mPose(pose_matrix.numpy())\n",
"image = res.image\n",
"\n",
"plt.style.use('default')\n",
"plt.style.use(\"default\")\n",
"fig = plt.figure(figsize=(20, 10))\n",
"ax1 = fig.add_subplot(1, 2, 1)\n",
"\n",
"ax1.imshow(image)\n",
"\n",
"ax2 = fig.add_subplot(1, 2, 2, projection='3d')\n",
"ax2 = fig.add_subplot(1, 2, 2, projection=\"3d\")\n",
"ax2.get_proj = lambda: np.dot(Axes3D.get_proj(ax2), np.diag([1, 1, 0.75, 1]))\n",
"ax2.view_init(45, 90)\n",
"ax2.view_init(45, 90)\n",
"ax2.set_xlim(3, -3)\n",
"ax2.set_ylim(3, -3)\n",
"ax2.set_zlim(-1, 1)\n",
"ax2.set_xlabel('x')\n",
"ax2.set_ylabel('y')\n",
"ax2.set_zlabel('z')\n",
"ax2.set_xlabel(\"x\")\n",
"ax2.set_ylabel(\"y\")\n",
"ax2.set_zlabel(\"z\")\n",
"ax2.set_zticks([])\n",
"\n",
"pose.plot(ax=ax2)\n",
"plt.show()\n"
"plt.show()"
],
"metadata": {
"collapsed": false,
notebooks/demo/static_loader_demo.ipynb: 4 changes (2 additions & 2 deletions)
@@ -28,7 +28,7 @@
"execution_count": 18,
"outputs": [],
"source": [
"dirname = '/Users/paulpierzchlewicz/PycharmProjects/propose/data/rat7m/s4-d1' # Choose this such that it points to your dataset"
"dirname = \"/Users/paulpierzchlewicz/PycharmProjects/propose/data/rat7m/s4-d1\" # Choose this such that it points to your dataset"
],
"metadata": {
"collapsed": false,
@@ -66,7 +66,7 @@
}
],
"source": [
"for i in dataloaders['train']:\n",
"for i in dataloaders[\"train\"]:\n",
" print(i.pose_matrix.shape)\n",
" print(i.adjacency_matrix.shape)\n",
" print(i.image.shape)\n",
notebooks/preprocess_rat7m.ipynb: 22 changes (12 additions & 10 deletions)
@@ -34,9 +34,9 @@
"execution_count": 2,
"outputs": [],
"source": [
"dirname = '/Users/paulpierzchlewicz/PycharmProjects/propose/data/rat7m' # Choose this such that it points to your dataset\n",
"data_key = 's4-d1'\n",
"mocap_path = f'{dirname}/mocap/mocap-{data_key}.mat'"
"dirname = \"/Users/paulpierzchlewicz/PycharmProjects/propose/data/rat7m\" # Choose this such that it points to your dataset\n",
"data_key = \"s4-d1\"\n",
"mocap_path = f\"{dirname}/mocap/mocap-{data_key}.mat\""
],
"metadata": {
"collapsed": false,
@@ -159,10 +159,10 @@
"source": [
"from pathlib import Path\n",
"\n",
"pose_dir = Path(f'{dirname}/{data_key}/poses')\n",
"pose_dir = Path(f\"{dirname}/{data_key}/poses\")\n",
"pose_dir.mkdir(parents=True, exist_ok=True)\n",
"\n",
"pose_path = pose_dir / f'{data_key}.npy'\n",
"pose_path = pose_dir / f\"{data_key}.npy\"\n",
"\n",
"mocap.save(pose_path)"
],
@@ -179,13 +179,14 @@
"outputs": [],
"source": [
"import pickle\n",
"camera_dir = Path(f'{dirname}/{data_key}/cameras')\n",
"\n",
"camera_dir = Path(f\"{dirname}/{data_key}/cameras\")\n",
"camera_dir.mkdir(parents=True, exist_ok=True)\n",
"\n",
"camera_path = camera_dir / f'{data_key}.pickle'\n",
"camera_path = camera_dir / f\"{data_key}.pickle\"\n",
"\n",
"with open(camera_path, 'wb') as f:\n",
" pickle.dump(cameras, f)\n"
"with open(camera_path, \"wb\") as f:\n",
" pickle.dump(cameras, f)"
],
"metadata": {
"collapsed": false,
@@ -209,7 +210,8 @@
],
"source": [
"from pathlib import PurePath\n",
"PurePath('/Users/paulpierzchlewicz/PycharmProjects/propose/data/rat7m').name"
"\n",
"PurePath(\"/Users/paulpierzchlewicz/PycharmProjects/propose/data/rat7m\").name"
],
"metadata": {
"collapsed": false,
propose/cameras/Camera.py: 4 changes (2 additions & 2 deletions)
@@ -1,8 +1,8 @@
from typing import Optional

import numpy as np
import numpy.typing as npt

from typing import Optional

Point2D = npt.NDArray[float]
Point3D = npt.NDArray[float]

propose/datasets/graph_transforms.py: 3 changes (1 addition & 2 deletions)
@@ -1,9 +1,8 @@
import numpy as np
import torch
import torch.distributions as D

from torch_geometric.loader.dataloader import Collater
from torch_geometric.data import HeteroData
from torch_geometric.loader.dataloader import Collater


class ScaleGraphPose(object):
propose/datasets/human36m/Human36mDataset.py: 18 changes (7 additions & 11 deletions)
@@ -1,20 +1,16 @@
import pickle

import numpy as np

from pathlib import Path

import numpy as np
import torch
import torch.distributions as D
from torch.utils.data import Dataset
from propose.poses.human36m import Human36mPose

from torch_geometric.data import HeteroData
from torch_geometric.loader.dataloader import Collater

import torch
import torch.distributions as D

from tqdm import tqdm

from propose.poses.human36m import Human36mPose


class Human36mDataset(Dataset):
"""
@@ -172,7 +168,7 @@ def __init__(

for p in occlusion_fractions:
mask = ~self.occlusions[i]
mask = np.insert(mask, 9, False)
mask = np.insert(mask, 8, False)

mask[: int(p * context_edges.shape[-1])] = 0

@@ -182,7 +178,7 @@ def __init__(

if mpii:
mask = ~self.occlusions[i]
mask = np.insert(mask, 9, False)
mask = np.insert(mask, 8, False)
rand_idx = np.random.choice(
np.arange(0, len(mask)), int(len(mask) * p), replace=False
)