Skip to content

Commit

Permalink
average keypoint distance & videos
Browse files Browse the repository at this point in the history
  • Loading branch information
tlpss committed Apr 2, 2024
1 parent c1ec45e commit 576b82f
Show file tree
Hide file tree
Showing 5 changed files with 376 additions and 0 deletions.
9 changes: 9 additions & 0 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -33,6 +33,15 @@ We have generated data for *almost-flattened* T-shirts, shorts and towels. The
In the paper, we test the efficacy of the synthetic data by training keypoint detectors and evaulating their performance on the [aRTF Clothes dataset](https://github.com/tlpss/aRTF-Clothes-dataset).
The repo contains code to reproduce all experiments from the paper. A number of trained keypoint detectors are also made available. Informations on how to use these checkpoints can be found here [here](state-estimation/Readme.md).


Live interactions with these checkpoints can be seen here:
<iframe width="560" height="315" src="https://www.youtube.com/embed/Mlwg_qPxr78?si=VntdFciJBo_Cr3j-" title="YouTube video player" frameborder="0" allow="accelerometer; autoplay; clipboard-write; encrypted-media; gyroscope; picture-in-picture; web-share" referrerpolicy="strict-origin-when-cross-origin" allowfullscreen></iframe>

In addition we show how the keypoints can be used for folding T-shirts below.
In this video we take a single image using the ego-centric camera and predict keypoints on that image. Based on these keypoints, a scripted sequence of fold motions is executed that allow us to fold T-shirts.

<iframe width="560" height="315" src="https://www.youtube.com/embed/bqnQ4iLnp20?si=jCaGCxDZSOYlDuo1" title="YouTube video player" frameborder="0" allow="accelerometer; autoplay; clipboard-write; encrypted-media; gyroscope; picture-in-picture; web-share" referrerpolicy="strict-origin-when-cross-origin" allowfullscreen></iframe>

## High-level overview of the codebase

```
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
**.json
199 changes: 199 additions & 0 deletions state-estimation/state_estimation/keypoint_detection/explore_akd.ipynb

Large diffs are not rendered by default.

Original file line number Diff line number Diff line change
Expand Up @@ -11,3 +11,37 @@
SYNTHETIC_TOWELS_CHECKPOINT = "tlips/cloth-keypoints-paper/model-la717c59:v0"
FINETUNED_SYNTHETIC_TOWELS_CHECKPOINT = "tlips/cloth-keypoints-paper/model-lcktxm1b:v0"
REAL_TOWELS_CHECKPOINT = "tlips/cloth-keypoints-paper/model-uc3ukfa0:v0"

TSHIRTS_MESH_CLOTH3D_CHECKPOINT = "tlips/cloth-keypoints-paper/model-e61sefpk:v0"
TSHIRTS_MESH_SINGLE_UNDEFORMED_CHECKPOINT = "tlips/cloth-keypoints-paper/model-7gcchgwg:v0"
TSHIRTS_MESH_SINGLE_CHECKPOINT = "tlips/cloth-keypoints-paper/model-4e73vowg:v0"

TSHIRTS_MATERIAL_RANDOM_CHECKPOINT = "tlips/cloth-keypoints-paper/model-dza0nezg:v0"
TSHIRTS_MATERIAL_HSV_CHECKPOINT = "tlips/cloth-keypoints-paper/model-lvoyd3vn:v0"
TSHIRTS_MATERIAL_TAILORED_CHECKPOINT = "tlips/cloth-keypoints-paper/model-ulxsei05:v0"

TSHIRTS_SIM_TO_SIM = "tlips/cloth-keypoints-paper/model-7xt2dg7l:v0"
TOWELS_SIM_TO_SIM = "tlips/cloth-keypoints-paper/model-oa457msf:v0"
SHORTS_SIM_TO_SIM = "tlips/cloth-keypoints-paper/model-35zjw47l:v0"


ARTIFACT_DICT = {
"synthetic_tshirts": SYNTHETIC_TSHIRTS_CHECKPOINT,
"finetuned_synthetic_tshirts": FINETUNED_SYNTHETIC_TSHIRTS_CHECKPOINT,
"real_tshirts": REAL_TSHIRTS_CHECKPOINT,
"synthetic_shorts": SYNTHETIC_SHORTS_CHECKPOINT,
"finetuned_synthetic_shorts": FINETUNED_SYNTHETIC_SHORTS_CHECKPOINT,
"real_shorts": REAL_SHORTS_CHECKPOINT,
"synthetic_towels": SYNTHETIC_TOWELS_CHECKPOINT,
"finetuned_synthetic_towels": FINETUNED_SYNTHETIC_TOWELS_CHECKPOINT,
"real_towels": REAL_TOWELS_CHECKPOINT,
"tshirts_mesh_cloth3d": TSHIRTS_MESH_CLOTH3D_CHECKPOINT,
"tshirts_mesh_single_undeformed": TSHIRTS_MESH_SINGLE_UNDEFORMED_CHECKPOINT,
"tshirts_mesh_single": TSHIRTS_MESH_SINGLE_CHECKPOINT,
"tshirts_material_random": TSHIRTS_MATERIAL_RANDOM_CHECKPOINT,
"tshirts_material_hsv": TSHIRTS_MATERIAL_HSV_CHECKPOINT,
"tshirts_material_tailored": TSHIRTS_MATERIAL_TAILORED_CHECKPOINT,
"tshirts_sim_to_sim": TSHIRTS_SIM_TO_SIM,
"towels_sim_to_sim": TOWELS_SIM_TO_SIM,
"shorts_sim_to_sim": SHORTS_SIM_TO_SIM,
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,133 @@
import torch
from keypoint_detection.data.coco_dataset import COCOKeypointsDataset
from keypoint_detection.models.detector import KeypointDetector
from keypoint_detection.utils.heatmap import get_keypoints_from_heatmap_batch_maxpool
from keypoint_detection.utils.load_checkpoints import get_model_from_wandb_checkpoint
from torch.utils.data import DataLoader
from tqdm import tqdm


def calculate_average_error_for_dataset(
model: KeypointDetector, dataset_json_path, channel_config: list[list[str]], detect_only_visible_keypoints
):
dataset = COCOKeypointsDataset(
dataset_json_path,
keypoint_channel_configuration=channel_config,
detect_only_visible_keypoints=detect_only_visible_keypoints,
)
dataloader = DataLoader(dataset, batch_size=1, shuffle=False, num_workers=1)

model.eval()
model.cuda()

errors = [[] for _ in range(len(channel_config))]
for image, keypoints in tqdm(dataloader):
image = image.cuda()
heatmaps = model(image)
keypoints = keypoints
predicted_keypoints = get_keypoints_from_heatmap_batch_maxpool(heatmaps, max_keypoints=1)[0]
for i in range(len(channel_config)):
if len(predicted_keypoints[i]) == 0:
print("no keypoints found")
continue
if len(keypoints[i]) == 0:
# print("no GT keypoints found")
continue
kp = torch.tensor(predicted_keypoints[i][0], dtype=torch.float32)
gt_kp = torch.tensor(keypoints[i][0], dtype=torch.float32)
l2_error = torch.norm(kp - gt_kp)
errors[i].append(l2_error.item())

average_errors = [sum(errors[i]) / len(errors[i]) for i in range(len(channel_config))]
for i in range(len(channel_config)):
print(f"Average error for channel {channel_config[i]}: {average_errors[i]}")
print(f"Average error: {sum(average_errors)/len(average_errors)}")

mae_dict = {}
full_dict = {}
for i in range(len(channel_config)):
channel_name = "" + "-".join(channel_config[i])
mae_dict[channel_name] = average_errors[i]
full_dict[channel_name] = errors[i]
mae_dict["average"] = sum(average_errors) / len(average_errors)

return mae_dict, full_dict


if __name__ == "__main__":
from state_estimation.keypoint_detection.common import (
SHORTS_CHANNEL_CONFIG,
TOWEL_CHANNEL_CONFIG,
TSHIRT_CHANNEL_CONFIG,
data_dir,
)
from state_estimation.keypoint_detection.final_checkpoints import ARTIFACT_DICT
from state_estimation.keypoint_detection.real_baselines import (
ARTF_SHORTS_TEST_PATH,
ARTF_TOWEL_TEST_PATH,
ARTF_TSHIRT_TEST_PATH,
)

error_dict = {}
for key, value in ARTIFACT_DICT.items():
if "tshirt" in key:
wandb_checkpoint = value
if "sim" in key:
dataset = (
data_dir
/ "synthetic-data"
/ "TSHIRT"
/ "single-layer-random-material-10K"
/ "annotations_val.json"
)
else:
dataset = ARTF_TSHIRT_TEST_PATH
keypoints = TSHIRT_CHANNEL_CONFIG.split(":")
keypoints = [channel.split(",") for channel in keypoints]
elif "towel" in key:
wandb_checkpoint = value
if "sim" in key:
dataset = (
data_dir / "synthetic-data" / "TOWEL" / "single-layer-random-material-10K" / "annotations_val.json"
)
else:
dataset = ARTF_TOWEL_TEST_PATH
keypoints = TOWEL_CHANNEL_CONFIG.split(":")
keypoints = [channel.split(",") for channel in keypoints]
elif "shorts" in key:
wandb_checkpoint = value
if "sim" in key:
dataset = (
data_dir
/ "synthetic-data"
/ "SHORTS"
/ "single-layer-random-material-10K"
/ "annotations_val.json"
)
else:
dataset = ARTF_SHORTS_TEST_PATH
keypoints = SHORTS_CHANNEL_CONFIG.split(":")
keypoints = [channel.split(",") for channel in keypoints]
else:
raise ValueError("Unknown artifact key")

print(f"Calculating average error for {key}")
print(f"dataset = {dataset}")
print(f"keypoints = {keypoints}")

model = get_model_from_wandb_checkpoint(wandb_checkpoint).cuda()
avg_errors, d = calculate_average_error_for_dataset(
model, dataset, keypoints, detect_only_visible_keypoints=True
)
error_dict[key] = avg_errors

# save dict as json
import json
import pathlib

file_path = pathlib.Path(__file__).parent
# with open(file_path / "average_keypoint_distances.json", "w") as f:
# json.dump(error_dict, f, indent=4)

with open(file_path / "akd" / f"{key}.json", "w") as f:
json.dump(d, f, indent=4)

0 comments on commit 576b82f

Please sign in to comment.