Skip to content

Commit

Permalink
Merge pull request #25 from PTG-Kitware/dev/dropout_aug
Browse files Browse the repository at this point in the history
Dev/dropout aug
  • Loading branch information
cfunk1210 authored May 21, 2024
2 parents fe03abe + 1bd562b commit e642c8b
Show file tree
Hide file tree
Showing 6 changed files with 152 additions and 5 deletions.
7 changes: 7 additions & 0 deletions configs/data/all_transforms/DropoutObjects.yaml
Original file line number Diff line number Diff line change
@@ -0,0 +1,7 @@
DropoutObjects:
_target_: tcn_hpl.data.components.augmentations.DropoutObjects
skip_stride: 3
dropout_last: false
num_obj_classes: 42
feat_version: 2
top_k_objects: 1
5 changes: 3 additions & 2 deletions configs/experiment/r18/feat_v6.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -16,13 +16,14 @@ defaults:

# all parameters below will be merged with parameters from default configurations set above

tags: ["r18", "ms_tcn"]
tags: ["r18", "ms_tcn", "debug"]

seed: 12345

trainer:
min_epochs: 50
max_epochs: 500
log_every_n_steps: 1


model:
Expand Down Expand Up @@ -74,7 +75,7 @@ data_gen:
data:
num_classes: 6 # activities: includes background
batch_size: 512
num_workers: 0
num_workers: 12
epoch_length: 20000
window_size: 25
sample_rate: 2 # ${IMAGE_HZ} / ${OBJECT_DET_HZ}
Expand Down
2 changes: 1 addition & 1 deletion configs/model/ptg.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,7 @@ scheduler:
_partial_: true
mode: min
factor: 0.1
patience: 10
patience: 100000

net:
_target_: tcn_hpl.models.components.ms_tcs_net.MultiStageModel
Expand Down
129 changes: 129 additions & 0 deletions tcn_hpl/data/components/augmentations.py
Original file line number Diff line number Diff line change
Expand Up @@ -727,3 +727,132 @@ def forward(self, features):
def __repr__(self) -> str:
detail = f"(im_w={self.im_w}, im_h={self.im_h}, feat_version={self.feat_version}, top_k_objects={self.top_k_objects})"
return f"{self.__class__.__name__}{detail}"



class DropoutObjects(torch.nn.Module):
"""Drop out Objects """

def __init__(self, dropout_probablity, num_obj_classes, feat_version, top_k_objects):
"""
:param dropout_probablity: probablity that a given frame will NOT have an object
"""
super().__init__()

raise NotImplementedError

self.dropout_probablity = dropout_probablity

self.num_obj_classes = num_obj_classes
self.num_non_obj_classes = 2 # hands
self.num_good_obj_classes = self.num_obj_classes - self.num_non_obj_classes

self.top_k_objects = top_k_objects

self.feat_version = feat_version
self.opts = feature_version_to_options(self.feat_version)

self.use_activation = self.opts.get("use_activation", False)
self.use_hand_dist = self.opts.get("use_hand_dist", False)
self.use_intersection = self.opts.get("use_intersection", False)
self.use_center_dist = self.opts.get("use_center_dist", False)
self.use_joint_hand_offset = self.opts.get("use_joint_hand_offset", False)
self.use_joint_object_offset = self.opts.get("use_joint_object_offset", False)

self.obj_feature_mask = list()

ind = -1
for object_k_index in range(self.top_k_objects):
# RIGHT HAND
if self.use_activation:
ind += 1 # right hand confidence
self.obj_feature_mask.append(0)

if self.use_hand_dist:
self.obj_feature_mask += [0]*2*self.num_good_obj_classes
ind += 2*self.num_good_obj_classes

if self.use_center_dist:
# right hand - image center distance
ind += 2
self.obj_feature_mask.append(0)
self.obj_feature_mask.append(0)

# LEFT HAND
if self.use_activation:
ind += 1 # left hand confidence
self.obj_feature_mask.append(0)

if self.use_hand_dist:
# Left hand distances
self.obj_feature_mask += [0]*2*self.num_good_obj_classes
ind += 2*self.num_good_obj_classes

if self.use_center_dist:
# left hand - image center distance
ind += 2
self.obj_feature_mask.append(0)
self.obj_feature_mask.append(0)

# Right - left hand
if self.use_hand_dist:
# Right - left hand distance
ind += 2
self.obj_feature_mask.append(0)
self.obj_feature_mask.append(0)
if self.use_intersection:
ind += 1 # right - left hadn intersection
self.obj_feature_mask.append(0)
# OBJECTS
for obj_ind in range(self.num_good_obj_classes):
if self.use_activation:
ind += 1 # Object confidence
self.obj_feature_mask.append(0)

if self.use_intersection:
# obj - hands intersection
ind += 2
self.obj_feature_mask.append(0)
self.obj_feature_mask.append(0)

if self.use_center_dist:
# image center - obj distances
ind += 2
self.obj_feature_mask.append(0)
self.obj_feature_mask.append(0)

# HANDS-JOINTS
if self.use_joint_hand_offset:
# left hand - joints distances
ind += 44
self.obj_feature_mask += [1]*44

# right hand - joints distances
ind += 44
self.obj_feature_mask += [1]*44

# OBJS-JOINTS
if self.use_joint_object_offset:
self.obj_feature_mask += [0]*44*self.top_k_objects*self.num_good_obj_classes
ind += 44*self.top_k_objects*self.num_good_obj_classes

self.obj_feature_mask = torch.tensor(self.obj_feature_mask)

def forward(self, features):
num_frames = features.shape[0]
# Pick starting location of random mask
start = random.randint(0,self.skip_stride)
# Create mask (one element for each frame)
mask = torch.rand(num_frames) > self.dropout_probablity

if self.dropout_last:
mask[-1] = 0



return features

def __repr__(self) -> str:
detail = f"(im_w={self.im_w}, im_h={self.im_h}, num_obj_classes={self.num_obj_classes}, feat_version={self.feat_version}, top_k_objects={self.top_k_objects})"
return f"{self.__class__.__name__}{detail}"

9 changes: 8 additions & 1 deletion tcn_hpl/models/ptg_module.py
Original file line number Diff line number Diff line change
Expand Up @@ -92,6 +92,9 @@ def __init__(

self.net = net

# Don't popup figures
plt.ioff()

self.topic = topic

# Get Action Names
Expand Down Expand Up @@ -364,6 +367,10 @@ def training_step(
"train/acc", self.train_acc, on_step=False, on_epoch=True, prog_bar=True
)

self.log(
"train/lr", self.lr_schedulers().get_last_lr()[0], on_step=False, on_epoch=True, prog_bar=True
)

self.training_step_outputs_target.append(targets[:, -1])
self.training_step_outputs_source_vid.append(source_vid[:, -1])
self.training_step_outputs_source_frame.append(source_frame[:, -1])
Expand Down Expand Up @@ -410,7 +417,7 @@ def on_train_epoch_end(self) -> None:
self.train_frames[video[:-4]] = train_fns

per_video_frame_gt_preds = {}

for (gt, pred, source_vid, source_frame) in zip(
all_targets, all_preds, all_source_vids, all_source_frames
):
Expand Down
5 changes: 4 additions & 1 deletion tcn_hpl/utils/logging_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,7 +24,7 @@ def log_hyperparameters(object_dict: Dict[str, Any]) -> None:
"""
hparams = {}

cfg = OmegaConf.to_container(object_dict["cfg"])
cfg = OmegaConf.to_container(object_dict["cfg"], resolve=True)
model = object_dict["model"]
trainer = object_dict["trainer"]

Expand Down Expand Up @@ -53,6 +53,9 @@ def log_hyperparameters(object_dict: Dict[str, Any]) -> None:
hparams["tags"] = cfg.get("tags")
hparams["ckpt_path"] = cfg.get("ckpt_path")
hparams["seed"] = cfg.get("seed")


hparams["cfg"] = cfg

# send hparams to all loggers
for logger in trainer.loggers:
Expand Down

0 comments on commit e642c8b

Please sign in to comment.